From ae93d71217c080c73c17a7ee5fd04f5d5dc65d93 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 16 Sep 2025 22:38:45 +0200 Subject: [PATCH 1/6] fixed simple annoying bugs --- README.md | 6 +++--- models/qwen/qwen_main.py | 1 + models/wan/any2video.py | 2 +- wgp.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 2411cea..8c5e436 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,12 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models - The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**) -- **First Frame / Last Frame for Vace** : Vace model are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias : +- **First Frame / Last Frame for Vace** : Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias : For a video Gen from scratch *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on .... -If you *Continue a Video* , you just need *"L L L"* since the the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*. +If you *Continue a Video* , you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*. -- **Qwen Inpainting** exist now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area. +- **Qwen Edit Inpainting** exists now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area. - **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so so lipsync when the audio provided contained some background music. The problem should be solved now thanks to an automated background music removal all done by IA. Don't worry you will still hear the music as it is added back in the generated Video. diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index da84e0d..cec0b8e 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -217,6 +217,7 @@ class model_factory(): def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs): if model_mode == 0: return [], [] preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + if len(preloadURLs) == 0: return [], [] return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] diff --git a/models/wan/any2video.py b/models/wan/any2video.py index dde1a65..3e85f79 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -146,7 +146,7 @@ class WanAny2V: from mmgp.safetensors2 import torch_load_file else: if self.transformer_switch: - if 0 in submodel_no_list[2:] and 1 in submodel_no_list: + if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]: raise Exception("Shared and non shared modules at the same time across multipe models is not supported") if 0 in submodel_no_list[2:]: diff --git a/wgp.py b/wgp.py index de8e2f6..ee26376 100644 --- a/wgp.py +++ b/wgp.py @@ -62,7 +62,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.6" +WanGP_version = "8.61" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None From fc615ffb3cb063bd4579ee98b2f81e7ca6633971 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 16 Sep 2025 23:01:54 +0200 Subject: [PATCH 2/6] fixed simple annoying bugs --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index ee26376..ee6b5f6 100644 --- a/wgp.py +++ b/wgp.py @@ -356,7 +356,7 @@ def process_prompt_and_add_tasks(state, model_choice): outpainting_dims = get_outpainting_dims(video_guide_outpainting) - if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None: + if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting") if len(loras_multipliers) > 0: From 84010bd861545fd2d8365aaece6b4017d40da9aa Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 22 Sep 2025 17:11:25 +0200 Subject: [PATCH 3/6] commit in case there is an unrecoverable code hemorragy --- configs/animate.json | 15 + configs/lucy_edit.json | 14 + defaults/animate.json | 13 + defaults/lucy_edit.json | 18 ++ defaults/ti2v_2_2_fastwan.json | 1 + models/qwen/pipeline_qwenimage.py | 144 +++++---- models/wan/any2video.py | 191 +++++------ models/wan/modules/model.py | 53 ++- models/wan/wan_handler.py | 81 +++-- shared/convert/convert_diffusers_to_flux.py | 342 ++++++++++++++++++++ shared/inpainting/__init__.py | 0 shared/inpainting/lanpaint.py | 240 ++++++++++++++ shared/inpainting/utils.py | 301 +++++++++++++++++ shared/utils/audio_video.py | 3 + shared/utils/download.py | 110 +++++++ shared/utils/utils.py | 246 ++++++-------- wgp.py | 228 +++++++++---- 17 files changed, 1620 insertions(+), 380 deletions(-) create mode 100644 configs/animate.json create mode 100644 configs/lucy_edit.json create mode 100644 defaults/animate.json create mode 100644 defaults/lucy_edit.json create mode 100644 shared/convert/convert_diffusers_to_flux.py create mode 100644 shared/inpainting/__init__.py create mode 100644 shared/inpainting/lanpaint.py create mode 100644 shared/inpainting/utils.py create mode 100644 shared/utils/download.py diff --git a/configs/animate.json b/configs/animate.json new file mode 100644 index 0000000..7e98ca9 --- /dev/null +++ b/configs/animate.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "motion_encoder_dim": 512 +} \ No newline at end of file diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json new file mode 100644 index 0000000..4983ced --- /dev/null +++ b/configs/lucy_edit.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 96, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/defaults/animate.json b/defaults/animate.json new file mode 100644 index 0000000..fee26d7 --- /dev/null +++ b/defaults/animate.json @@ -0,0 +1,13 @@ +{ + "model": { + "name": "Wan2.2 Animate", + "architecture": "animate", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" + ], + "group": "wan2_2" + } +} \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json new file mode 100644 index 0000000..6344dff --- /dev/null +++ b/defaults/lucy_edit.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.2 Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "video_length": 81, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 30, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index 064c2b4..fa69f82 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -7,6 +7,7 @@ "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, + "prompt" : "Put the person into a clown outfit.", "video_length": 121, "guidance_scale": 1, "flow_shift": 3, diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index be982aa..1472886 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mmgp import offload import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -387,7 +386,8 @@ class QwenImagePipeline(): #DiffusionPipeline return latent_image_ids.to(device=device, dtype=dtype) @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): + def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -479,7 +479,7 @@ class QwenImagePipeline(): #DiffusionPipeline height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, 1, num_channels_latents, height, width) + shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None if image is not None: @@ -499,10 +499,7 @@ class QwenImagePipeline(): #DiffusionPipeline else: image_latents = torch.cat([image_latents], dim=0) - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) + image_latents = self._pack_latents(image_latents) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -511,7 +508,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents) else: latents = latents.to(device=device, dtype=dtype) @@ -713,11 +710,12 @@ class QwenImagePipeline(): #DiffusionPipeline image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of) # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) height, width = image_height, image_width - image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] - image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") - image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) prompt_image = image if image.size != (image_width, image_height): @@ -822,6 +820,7 @@ class QwenImagePipeline(): #DiffusionPipeline negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) morph, first_step = False, 0 + lanpaint_proc = None if image_mask_latents is not None: randn = torch.randn_like(original_image_latents) if denoising_strength < 1.: @@ -833,7 +832,8 @@ class QwenImagePipeline(): #DiffusionPipeline timesteps = timesteps[first_step:] self.scheduler.timesteps = timesteps self.scheduler.sigmas= self.scheduler.sigmas[first_step:] - + # from shared.inpainting.lanpaint import LanPaint + # lanpaint_proc = LanPaint() # 6. Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) @@ -847,48 +847,52 @@ class QwenImagePipeline(): #DiffusionPipeline offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph: latent_noise_factor = t/1000 latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor - self._current_timestep = t - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) + latents_dtype = latents.dtype - if do_true_cfg and joint_pass: - noise_pred, neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - ) - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - )[0] - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] + # latent_model_input = latents + def denoise(latent_model_input, true_cfg_scale): + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + do_true_cfg = true_cfg_scale > 1 + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, #!!!! + encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + neg_noise_pred = None + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( @@ -902,27 +906,43 @@ class QwenImagePipeline(): #DiffusionPipeline attention_kwargs=self.attention_kwargs, **kwargs )[0] - if neg_noise_pred == None: return None + if neg_noise_pred == None: return None, None neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + return noise_pred, neg_noise_pred + def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): + if do_true_cfg: + comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred) + if comb_pred == None: return None - if do_true_cfg: - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if comb_pred == None: return None + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - neg_noise_pred = None + return noise_pred + + + if lanpaint_proc is not None and i<=3: + latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8) + if latents is None: return None + + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None + noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + neg_noise_pred = None # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = None + if image_mask_latents is not None: - next_t = timesteps[i+1] if i0: + msk[:, :nb_frames_unchanged] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1,2)[0] + return msk + def generate(self, input_prompt, input_frames= None, input_masks = None, - input_ref_images = None, + input_ref_images = None, + input_ref_masks = None, + input_faces = None, input_video = None, image_start = None, image_end = None, @@ -541,14 +506,18 @@ class WanAny2V: infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] recam = model_type in ["recam_1.3B"] - ti2v = model_type in ["ti2v_2_2"] + ti2v = model_type in ["ti2v_2_2", "lucy_edit"] + lucy_edit= model_type in ["lucy_edit"] + animate= model_type in ["animate"] start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = None + extended_overlapped_latents = clip_image_start = clip_image_end = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + extended_input_dim = 0 + ref_images_before = False # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False @@ -598,17 +567,7 @@ class WanAny2V: if image_end is not None: img_end_frame = image_end.unsqueeze(1).to(self.device) - - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None + clip_image_start, clip_image_end = image_start, image_end if any_end_frame: enc= torch.concat([ @@ -647,21 +606,62 @@ class WanAny2V: if infinitetalk: lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - # if control_pre_frames_count != pre_frames_count: lat_y = input_video = None kwargs.update({ 'y': y}) - if not clip_context is None: - kwargs.update({'clip_fea': clip_context}) - # Recam Master - if recam: - target_camera = model_mode - height,width = input_frames.shape[-2:] - input_frames = input_frames.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + # Animate + if animate: + pose_pixels = input_frames * input_masks + input_masks = 1. - input_masks + pose_pixels -= input_masks + save_video(pose_pixels, "pose.mp4") + pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) + input_frames = input_frames * input_masks + if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + if prefix_frames_count > 0: + input_frames[:, :prefix_frames_count] = input_video + input_masks[:, :prefix_frames_count] = 1 + save_video(input_frames, "input_frames.mp4") + save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) + msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) + msk = torch.concat([msk_ref, msk_control], dim=1) + clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) + lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + y = torch.concat([msk, lat_y]) + kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) + lat_y = msk = msk_control = msk_ref = pose_pixels = None + ref_images_before = True + ref_images_count = 1 + lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1 + + # Clip image + if hasattr(self, "clip") and clip_image_start is not None: + clip_image_size = self.clip.model.image_size + clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size) + clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]]) + else: + clip_context = self.clip.visual([clip_image_start[:, None, :, :]]) + clip_image_start = clip_image_end = None + kwargs.update({'clip_fea': clip_context}) + + # Recam Master & Lucy Edit + if recam or lucy_edit: + frame_num, height,width = input_frames.shape[-3:] + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + frame_num = (lat_frames -1) * self.vae_stride[0] + 1 + input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device) + extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + extended_input_dim = 2 if recam else 1 del input_frames + + if recam: # Process target camera (recammaster) + target_camera = model_mode from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) @@ -715,6 +715,8 @@ class WanAny2V: height, width = input_video.shape[-2:] source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) timestep_injection = True + if extended_input_dim > 0: + extended_latents[:, :, :source_latents.shape[2]] = source_latents # Vace if vace : @@ -722,6 +724,7 @@ class WanAny2V: input_frames = [u.to(self.device) for u in input_frames] input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] input_masks = [u.to(self.device) for u in input_masks] + ref_images_before = True if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) @@ -771,9 +774,9 @@ class WanAny2V: expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - if target_camera != None: + if extended_input_dim>=2: shape = list(target_shape[1:]) - shape[0] *= 2 + shape[extended_input_dim-2] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) @@ -901,8 +904,8 @@ class WanAny2V: for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor - if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) + if extended_input_dim > 0: + latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim) else: latent_model_input = latents @@ -1030,7 +1033,7 @@ class WanAny2V: if callback is not None: latents_preview = latents - if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1041,7 +1044,7 @@ class WanAny2V: if timestep_injection: latents[:, :, :source_latents.shape[2]] = source_latents - if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: latent_slice = latents[:, :, return_latent_slice].clone() @@ -1078,4 +1081,12 @@ class WanAny2V: delattr(model, "vace_blocks") + def adapt_animate_model(self, model): + modules_dict= { k: m for k, m in model.named_modules()} + for animate_layer in range(8): + module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"] + model_layer = animate_layer * 5 + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "face_adapter_fuser_blocks", module ) + delattr(model, "face_adapter") diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index d3dc783..c98087b 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -16,6 +16,9 @@ from mmgp.offload import get_cache, clear_caches from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target +from ..animate.motion_encoder import Generator +from ..animate.face_blocks import FaceAdapter, FaceEncoder +from ..animate.model_animate import after_patch_embedding __all__ = ['WanModel'] @@ -499,6 +502,7 @@ class WanAttentionBlock(nn.Module): multitalk_masks=None, ref_images_count=0, standin_phase=-1, + motion_vec = None, ): r""" Args: @@ -616,6 +620,10 @@ class WanAttentionBlock(nn.Module): x.add_(hint) else: x.add_(hint, alpha= scale) + + if motion_vec is not None and self.block_no % 5 == 0: + x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False) + return x class AudioProjModel(ModelMixin, ConfigMixin): @@ -898,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin): norm_input_visual=True, norm_output_audio=True, standin= False, + motion_encoder_dim=0, ): super().__init__() @@ -922,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin): self.flag_causal_attention = False self.block_mask = None self.inject_sample_info = inject_sample_info - + self.motion_encoder_dim = motion_encoder_dim self.norm_output_audio = norm_output_audio self.audio_window = audio_window self.intermediate_dim = intermediate_dim self.vae_scale = vae_scale multitalk = multitalk_output_dim > 0 - self.multitalk = multitalk + self.multitalk = multitalk + animate = motion_encoder_dim > 0 # embeddings self.patch_embedding = nn.Conv3d( @@ -1027,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin): block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if animate: + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1208,6 +1237,9 @@ class WanModel(ModelMixin, ConfigMixin): ref_images_count = 0, standin_freqs = None, standin_ref = None, + pose_latents=None, + face_pixel_values=None, + ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1240,9 +1272,18 @@ class WanModel(ModelMixin, ConfigMixin): if bz > 1: y = y.expand(bz, -1, -1, -1, -1) x = torch.cat([x, y], dim=1) # embeddings - # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] + x_list[i] = x + y = None + + motion_vec_list = [] + for i, x in enumerate(x_list): + # animate embeddings + motion_vec = None + if pose_latents is not None: + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values) + motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) x_og_shape = x.shape @@ -1250,7 +1291,7 @@ class WanModel(ModelMixin, ConfigMixin): else: x = x.flatten(2).transpose(1, 2) x_list[i] = x - x, y = None, None + x = None block_mask = None @@ -1450,9 +1491,9 @@ class WanModel(ModelMixin, ConfigMixin): continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs) del x context = hints = audio_embedding = None diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 6586176..ff1570f 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -3,10 +3,10 @@ import numpy as np import gradio as gr def test_class_i2v(base_model_type): - return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] def text_oneframe_overlap(base_model_type): - return test_class_i2v(base_model_type) and not test_multitalk(base_model_type) + return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type) def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -17,6 +17,8 @@ def test_multitalk(base_model_type): def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_wan_5B(base_model_type): + return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @staticmethod @@ -36,7 +38,7 @@ class family_handler(): def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] elif base_model_type in ["i2v_2_2"]: def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] else: # i2v @@ -83,11 +85,13 @@ class family_handler(): vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] extra_model_def["vace_class"] = vace_class - if test_multitalk(base_model_type): + if base_model_type in ["animate"]: + fps = 30 + elif test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): fps = 24 else: fps = 16 @@ -100,14 +104,14 @@ class family_handler(): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "convert_image_guide_to_video" : True, @@ -146,6 +150,34 @@ class family_handler(): } # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["lucy_edit"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible": False, + } + + if base_model_type in ["animate"]: + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video Person in Reference Image", "PVBAI"), + ], + "default": "KI", + "letters_filter": "PVBXAKI", + "label": "Type of Process", + "show_label" : False, + } + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["extract_guide_from_window_start"] = True + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" + extra_model_def["background_ref_outpainted"] = False + + + if vace_class: extra_model_def["guide_preprocessing"] = { "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], @@ -157,16 +189,19 @@ class family_handler(): extra_model_def["image_ref_choices"] = { "choices": [("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), ], "letters_filter": "KFI", } extra_model_def["lock_image_refs_ratios"] = True - extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["pad_guide_video"] = True + extra_model_def["guide_inpaint_color"] = 127.5 + extra_model_def["forced_guide_mask_inputs"] = True if base_model_type in ["standin"]: extra_model_def["lock_image_refs_ratios"] = True @@ -209,10 +244,12 @@ class family_handler(): "visible" : False, } - if vace_class or base_model_type in ["infinitetalk"]: + if vace_class or base_model_type in ["infinitetalk", "animate"]: image_prompt_types_allowed = "TVL" elif base_model_type in ["ti2v_2_2"]: image_prompt_types_allowed = "TSVL" + elif base_model_type in ["lucy_edit"]: + image_prompt_types_allowed = "TVL" elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]: image_prompt_types_allowed = "SVL" elif i2v: @@ -234,8 +271,8 @@ class family_handler(): def query_supported_types(): return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + "recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @staticmethod @@ -265,11 +302,12 @@ class family_handler(): @staticmethod def get_vae_block_size(base_model_type): - return 32 if base_model_type == "ti2v_2_2" else 16 + return 32 if test_wan_5B(base_model_type) else 16 @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors + if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2" latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @@ -283,7 +321,7 @@ class family_handler(): "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] }] - if base_model_type == "ti2v_2_2": + if test_wan_5B(base_model_type): download_def += [ { "repoId" : "DeepBeepMeep/Wan2.2", "sourceFolderList" : [""], @@ -377,8 +415,8 @@ class family_handler(): ui_defaults.update({ "sample_solver": "unipc", }) - if test_class_i2v(base_model_type): - ui_defaults["image_prompt_type"] = "S" + if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]: + ui_defaults["image_prompt_type"] = "S" if base_model_type in ["fantasy"]: ui_defaults.update({ @@ -434,10 +472,15 @@ class family_handler(): "image_prompt_type": "T", }) - if base_model_type in ["recam_1.3B"]: + if base_model_type in ["recam_1.3B", "lucy_edit"]: ui_defaults.update({ "video_prompt_type": "UV", }) + elif base_model_type in ["animate"]: + ui_defaults.update({ + "video_prompt_type": "PVBXAKI", + "mask_expand": 20, + }) if text_oneframe_overlap(base_model_type): ui_defaults["sliding_window_overlap"] = 1 diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py new file mode 100644 index 0000000..608b176 --- /dev/null +++ b/shared/convert/convert_diffusers_to_flux.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" +Convert a Flux model from Diffusers (folder or single-file) into the original +single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI. + +Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file) +Output : /path/to/flux1-your-model.safetensors (transformer only) + +Usage: + python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors + python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors + # optional quantization: + # --fp8 (float8_e4m3fn, simple) + # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors) +""" + +import argparse +import json +from pathlib import Path +from collections import OrderedDict + +import torch +from safetensors import safe_open +import safetensors.torch +from tqdm import tqdm + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("diffusers_path", type=str, + help="Path to Diffusers checkpoint folder OR a single .safetensors file.") + ap.add_argument("output_path", type=str, + help="Output .safetensors path for the Flux transformer.") + ap.add_argument("--fp8", action="store_true", + help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).") + ap.add_argument("--fp8-scaled", action="store_true", + help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.") + return ap.parse_args() + + +# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable). +DIFFUSERS_MAP = { + # global embeds + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + + # dual-stream (image/text) blocks + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + + # single-stream blocks + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + + # final + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift] + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +class DiffusersSource: + """ + Uniform interface over: + 1) Folder with index JSON + shards + 2) Folder with exactly one .safetensors (no index) + 3) Single .safetensors file + Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning) + """ + + POSSIBLE_PREFIXES = ["", "model."] # try in this order + + def __init__(self, path: Path): + p = Path(path) + if p.is_dir(): + # use 'transformer' subfolder if present + if (p / "transformer").is_dir(): + p = p / "transformer" + self._init_from_dir(p) + elif p.is_file() and p.suffix == ".safetensors": + self._init_from_single_file(p) + else: + raise FileNotFoundError(f"Invalid path: {p}") + + # ---------- common helpers ---------- + + @staticmethod + def _strip_prefix(k: str) -> str: + return k[6:] if k.startswith("model.") else k + + def _resolve(self, want: str): + """ + Return the actual stored key matching `want` by trying known prefixes. + """ + for pref in self.POSSIBLE_PREFIXES: + k = pref + want + if k in self._all_keys: + return k + return None + + def has(self, want: str) -> bool: + return self._resolve(want) is not None + + def get(self, want: str) -> torch.Tensor: + real_key = self._resolve(want) + if real_key is None: + raise KeyError(f"Missing key: {want}") + return self._get_by_real_key(real_key).to("cpu") + + @property + def base_keys(self): + # keys without 'model.' prefix for scanning + return [self._strip_prefix(k) for k in self._all_keys] + + # ---------- modes ---------- + + def _init_from_single_file(self, file_path: Path): + self._mode = "single" + self._file = file_path + self._handle = safe_open(file_path, framework="pt", device="cpu") + self._all_keys = list(self._handle.keys()) + + def _get_by_real_key(real_key: str): + return self._handle.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + + def _init_from_dir(self, dpath: Path): + index_json = dpath / "diffusion_pytorch_model.safetensors.index.json" + if index_json.exists(): + with open(index_json, "r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] # full mapping + self._mode = "sharded" + self._dpath = dpath + self._weight_map = {k: dpath / v for k, v in weight_map.items()} + self._all_keys = list(self._weight_map.keys()) + self._open_handles = {} + + def _get_by_real_key(real_key: str): + fpath = self._weight_map[real_key] + h = self._open_handles.get(fpath) + if h is None: + h = safe_open(fpath, framework="pt", device="cpu") + self._open_handles[fpath] = h + return h.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + return + + # no index: try exactly one safetensors in folder + files = sorted(dpath.glob("*.safetensors")) + if len(files) != 1: + raise FileNotFoundError( + f"No index found and {dpath} does not contain exactly one .safetensors file." + ) + self._init_from_single_file(files[0]) + + +def main(): + args = parse_args() + src = DiffusersSource(Path(args.diffusers_path)) + + # Count blocks by scanning base keys (with any 'model.' prefix removed) + num_dual = 0 + num_single = 0 + for k in src.base_keys: + if k.startswith("transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_dual = max(num_dual, i + 1) + except Exception: + pass + elif k.startswith("single_transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_single = max(num_single, i + 1) + except Exception: + pass + print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks") + + # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0) + def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor: + shift, scale = vec.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + orig = {} + + # Per-block (dual) + for b in range(num_dual): + prefix = f"transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("double_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Per-block (single) + for b in range(num_single): + prefix = f"single_transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("single_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Globals (non-block) + for okey, dvals in DIFFUSERS_MAP.items(): + if okey.startswith(("double_blocks.", "single_blocks.")): + continue + dkeys = dvals + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey] = src.get(dkeys[0]) + else: + orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves + if "final_layer.adaLN_modulation.1.weight" in orig: + orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.weight"] + ) + if "final_layer.adaLN_modulation.1.bias" in orig: + orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.bias"] + ) + + # Optional FP8 variants (experimental; not required for ComfyUI/BFL) + if args.fp8 or args.fp8_scaled: + dtype = torch.float8_e4m3fn # noqa + minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max + + def stochastic_round_to(t): + t = t.float().clamp(minv, maxv) + lower = torch.floor(t * 256) / 256 + upper = torch.ceil(t * 256) / 256 + prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t)) + rnd = torch.rand_like(t) + out = torch.where(rnd < prob, upper, lower) + return out.to(dtype) + + def scale_to_8bit(weight, target_max=416.0): + absmax = weight.abs().max() + scale = absmax / target_max if absmax > 0 else torch.tensor(1.0) + scaled = (weight / scale).clamp(minv, maxv).to(dtype) + return scaled, scale + + scales = {} + for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"): + t = orig[k] + if args.fp8: + orig[k] = stochastic_round_to(t) + else: + if k.endswith(".weight") and t.dim() == 2: + qt, s = scale_to_8bit(t) + orig[k] = qt + scales[k[:-len(".weight")] + ".scale_weight"] = s + else: + orig[k] = t.clamp(minv, maxv).to(dtype) + if args.fp8_scaled: + orig.update(scales) + orig["scaled_fp8"] = torch.tensor([], dtype=dtype) + else: + # Default: save in bfloat16 + for k in list(orig.keys()): + orig[k] = orig[k].to(torch.bfloat16).cpu() + + out_path = Path(args.output_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + meta = OrderedDict() + meta["format"] = "pt" + meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d") + print(f"Saving transformer to: {out_path}") + safetensors.torch.save_file(orig, str(out_path), metadata=meta) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py new file mode 100644 index 0000000..3165e7b --- /dev/null +++ b/shared/inpainting/lanpaint.py @@ -0,0 +1,240 @@ +import torch +from .utils import * +from functools import partial + +# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/) + +def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape + + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + +def _unpack_latents(latents, height, width, vae_scale_factor=8): + batch_size, num_patches, channels = latents.shape + + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + +class LanPaint(): + def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False): + self.n_steps = NSteps + self.chara_lamb = Lambda + self.IS_FLUX = IS_FLUX + self.IS_FLOW = IS_FLOW + self.step_size = StepSize + self.friction = Friction + self.chara_beta = Beta + self.img_dim_size = None + def add_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (None,) * (self.img_dim_size-1) + return array[index] + def remove_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (0,) * (self.img_dim_size-1) + return array[index] + def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8): + latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor) + noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor) + x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor) + latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor) + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.img_dim_size = len(x.shape) + self.latent_image = latent_image + self.noise = noise + if n_steps is None: + n_steps = self.n_steps + out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW) + out = _pack_latents(out) + return out + def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW): + if IS_FLUX: + cfg_BIG = 1.0 + + def double_denoise(latents, t): + latents = _pack_latents(latents) + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None, None + predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor) + if true_cfg_scale == cfg_BIG: + predict_big = predict_std + else: + predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t) + predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor) + return predict_std, predict_big + + if len(sigma.shape) == 0: + sigma = torch.tensor([sigma.item()]) + latent_mask = 1 - latent_mask + if IS_FLUX or IS_FLOW: + Flow_t = sigma + abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 ) + VE_Sigma = Flow_t / (1 - Flow_t) + #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item()) + else: + VE_Sigma = sigma + abt = 1/( 1+VE_Sigma**2 ) + Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 ) + # VE_Sigma, abt, Flow_t = current_times + current_times = (VE_Sigma, abt, Flow_t) + + step_size = self.step_size * (1 - abt) + step_size = self.add_none_dims(step_size) + # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values + # This is the replace step + # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask + + noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma + x = x * (1 - latent_mask) + noisy_image * latent_mask + + if IS_FLUX or IS_FLOW: + x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + + ############ LanPaint Iterations Start ############### + # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region. + args = None + for i in range(n_steps): + score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise ) + if score_func is None: return None + x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args) + if IS_FLUX or IS_FLOW: + x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + ############ LanPaint Iterations End ############### + # out is x_0 + # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed) + # out = out * (1-latent_mask) + self.latent_image * latent_mask + # return out + return x + + def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func): + + lamb = self.chara_lamb + if self.IS_FLUX or self.IS_FLOW: + # compute t for flow model, with a small epsilon compensating for numerical error. + x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow)) + if x_0 is None: return None + else: + x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma)) + if x_0 is None: return None + + score_x = -(x_t - x_0) + score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG) + return score_x * (1 - mask) + score_y * mask + def sigma_x(self, abt): + # the time scale for the x_t update + return abt**0 + def sigma_y(self, abt): + beta = self.chara_beta * abt ** 0 + return beta + + def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None): + # prepare the step size and time parameters + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y) + sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes + # print('mask',mask.device) + if torch.mean(dtx) <= 0.: + return x_t, args + # ------------------------------------------------------------------------- + # Compute the Langevin dynamics update in variance perserving notation + # ------------------------------------------------------------------------- + #x0 = self.x0_evalutation(x_t, score, sigma, args) + #C = abt**0.5 * x0 / (1-abt) + A = A_x * (1-mask) + A_y * mask + D = D_x * (1-mask) + D_y * mask + dt = dtx * (1-mask) + dty * mask + Gamma = Gamma_x * (1-mask) + Gamma_y * mask + + + def Coef_C(x_t): + x0 = self.x0_evalutation(x_t, score, sigma, args) + C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t + return C + def advance_time(x_t, v, dt, Gamma, A, C, D): + dtype = x_t.dtype + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + osc = StochasticHarmonicOscillator(Gamma, A, C, D ) + x_t, v = osc.dynamics(x_t, v, dt ) + x_t = x_t.to(dtype) + v = v.to(dtype) + return x_t, v + if args is None: + #v = torch.zeros_like(x_t) + v = None + C = Coef_C(x_t) + #print(torch.squeeze(dtx), torch.squeeze(dty)) + x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D) + else: + v, C = args + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C_new = Coef_C(x_t) + v = v + Gamma**0.5 * ( C_new - C) *dt + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C = C_new + + return x_t, (v, C) + + def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y): + # ------------------------------------------------------------------------- + # Unpack current times parameters (sigma and abt) + sigma, abt, flow_t = current_times + sigma = self.add_none_dims(sigma) + abt = self.add_none_dims(abt) + # Compute time step (dtx, dty) for x and y branches. + dtx = 2 * step_size * sigma_x + dty = 2 * step_size * sigma_y + + # ------------------------------------------------------------------------- + # Define friction parameter Gamma_hat for each branch. + # Using dtx**0 provides a tensor of the proper device/dtype. + + Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0 + Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0 + #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item()) + # adjust dt to match denoise-addnoise steps sizes + Gamma_hat_x /= 2. + Gamma_hat_y /= 2. + A_t_x = (1) / ( 1 - abt ) * dtx / 2 + A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2 + + + A_x = A_t_x / (dtx/2) + A_y = A_t_y / (dty/2) + Gamma_x = Gamma_hat_x / (dtx/2) + Gamma_y = Gamma_hat_y / (dty/2) + + #D_x = (2 * (1 + sigma**2) )**0.5 + #D_y = (2 * (1 + sigma**2) )**0.5 + D_x = (2 * abt**0 )**0.5 + D_y = (2 * abt**0 )**0.5 + return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y + + + + def x0_evalutation(self, x_t, score, sigma, args): + x0 = x_t + score(x_t) + return x0 \ No newline at end of file diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py new file mode 100644 index 0000000..c017ab0 --- /dev/null +++ b/shared/inpainting/utils.py @@ -0,0 +1,301 @@ +import torch +def epxm1_x(x): + # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero. + result = torch.special.expm1(x) / x + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x) < 1e-2 + result = torch.where(mask, 1 + x/2. + x**2 / 6., result) + return result +def epxm1mx_x2(x): + # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x) / x**2 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**2) < 1e-2 + result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result) + return result + +def expm1mxmhx2_x3(x): + # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x - x**2 / 2) / x**3 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**3) < 1e-2 + result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result) + return result + +def exp_1mcosh_GD(gamma_t, delta): + """ + Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) ) + numerator = torch.where(is_positive, numerator_pos, numerator_neg) + result = numerator / (delta * gamma_t**2 ) + # Handle NaN/inf cases + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + # Handle numerical instability for small delta + mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2 + taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t) + result = torch.where(mask, taylor, result) + return result + +def exp_sinh_GsqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + denominator_pos = gamma_t_sqrt_delta + result_pos = numerator_pos / gamma_t_sqrt_delta + result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos)) + + # Taylor expansion for small gamma_t_sqrt_delta + mask = torch.abs(gamma_t_sqrt_delta) < 1e-2 + taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t) + result_pos = torch.where(mask, taylor, result_pos) + + # Handle negative delta + result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi) + result = torch.where(is_positive, result_pos, result_neg) + return result + +def exp_cosh(gamma_t, delta): + """ + Compute e^(-Γt) * cosh(Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result + return result +def exp_sinh_sqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / √Δ + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + Returns: + Result of the computation with numerical stability handling + """ + exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + result = gamma_t * exp_sinh_GsqrtD_result + return result + + + +def zeta1(gamma_t, delta): + # Compute hyperbolic terms and exponential + half_gamma_t = gamma_t / 2 + exp_cosh_term = exp_cosh(half_gamma_t, delta) + exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta) + + + # Main computation + numerator = 1 - (exp_cosh_term + exp_sinh_term) + denominator = gamma_t * (1 - delta) / 4 + result = 1 - numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small x (similar to your epxm1Dx approach) + mask = torch.abs(denominator) < 5e-3 + term1 = epxm1_x(-gamma_t) + term2 = epxm1mx_x2(-gamma_t) + term3 = expm1mxmhx2_x3(-gamma_t) + taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2 + result = torch.where(mask, taylor, result) + + return result + +def exp_cosh_minus_terms(gamma_t, delta): + """ + Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ)) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_term = torch.exp(-gamma_t) + # Compute individual terms + exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term + exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term + + #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + # Main computation + numerator = exp_cosh_term - exp_cosh_delta_term + denominator = gamma_t * (1 - delta) + + result = numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small gamma_t and delta near 1 + mask = (torch.abs(denominator) < 1e-1) + exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0) + taylor = ( + gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) + - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) ) + ) + result = torch.where(mask, taylor, result) + + return result + + +def zeta2(gamma_t, delta): + half_gamma_t = gamma_t / 2 + return exp_sinh_GsqrtD(half_gamma_t, delta) + +def sig11(gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + + +def Zcoefs(gamma_t, delta): + Zeta1 = zeta1(gamma_t, delta) + Zeta2 = zeta2(gamma_t, delta) + + sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8 + amplitude = torch.sqrt(sq_total) + Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude + Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5 + #cterm = exp_cosh_minus_terms(gamma_t, delta) + #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta) + #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) ) + Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) ) + + return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude + +def Zcoefs_asymp(gamma_t, delta): + A_t = (gamma_t * (1 - delta) )/4 + return epxm1_x(- 2 * A_t) + +class StochasticHarmonicOscillator: + """ + Simulates a stochastic harmonic oscillator governed by the equations: + dy(t) = q(t) dt + dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt + + Also define v(t) = q(t) / √Γ, which is numerically more stable. + + Where: + y(t) - Position variable + q(t) - Velocity variable + Γ - Damping coefficient + A - Harmonic potential strength + C - Constant force term + D - Noise amplitude + dw(t) - Wiener process (Brownian motion) + """ + def __init__(self, Gamma, A, C, D): + self.Gamma = Gamma + self.A = A + self.C = C + self.D = D + self.Delta = 1 - 4 * A / Gamma + def sig11(self, gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + def sig22(self, gamma_t, delta): + return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) + def dynamics(self, y0, v0, t): + """ + Calculates the position and velocity variables at time t. + + Parameters: + y0 (float): Initial position + v0 (float): Initial velocity v(0) = q(0) / √Γ + t (float): Time at which to evaluate the dynamics + Returns: + tuple: (y(t), v(t)) + """ + + dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0 + Delta = self.Delta + dummyzero + Gamma_hat = self.Gamma * t + dummyzero + A = self.A + dummyzero + C = self.C + dummyzero + D = self.D + dummyzero + Gamma = self.Gamma + dummyzero + zeta_1 = zeta1( Gamma_hat, Delta) + zeta_2 = zeta2( Gamma_hat, Delta) + EE = 1 - Gamma_hat * zeta_2 + + if v0 is None: + v0 = torch.randn_like(y0) * D / 2 ** 0.5 + #v0 = (C - A * y0)/Gamma**0.5 + + # Calculate mean position and velocity + term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t + y_mean = term1 + y0 + v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0 + + cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta) + cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2 + cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5) + + # sample new position and velocity with multivariate normal distribution + + batch_shape = y0.shape + cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + cov_matrix[..., 0, 0] = cov_yy + cov_matrix[..., 0, 1] = cov_yv + cov_matrix[..., 1, 0] = cov_yv # symmetric + cov_matrix[..., 1, 1] = cov_vv + + + + # Compute the Cholesky decomposition to get scale_tril + #scale_tril = torch.linalg.cholesky(cov_matrix) + scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + tol = 1e-8 + cov_yy = torch.clamp( cov_yy, min = tol ) + sd_yy = torch.sqrt( cov_yy ) + inv_sd_yy = 1/(sd_yy) + + scale_tril[..., 0, 0] = sd_yy + scale_tril[..., 0, 1] = 0. + scale_tril[..., 1, 0] = cov_yv * inv_sd_yy + scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5 + # check if it matches torch.linalg. + #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 ) + # Sample correlated noise from multivariate normal + mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype) + mean[..., 0] = y_mean + mean[..., 1] = v_mean + new_yv = torch.distributions.MultivariateNormal( + loc=mean, + scale_tril=scale_tril + ).sample() + + return new_yv[...,0], new_yv[...,1] \ No newline at end of file diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py index b24530d..224cf34 100644 --- a/shared/utils/audio_video.py +++ b/shared/utils/audio_video.py @@ -232,6 +232,9 @@ def save_video(tensor, retry=5): """Save tensor as video with configurable codec and container options.""" + if torch.is_tensor(tensor) and len(tensor.shape) == 4: + tensor = tensor.unsqueeze(0) + suffix = f'.{container}' cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file if not cache_file.endswith(suffix): diff --git a/shared/utils/download.py b/shared/utils/download.py new file mode 100644 index 0000000..ed035c0 --- /dev/null +++ b/shared/utils/download.py @@ -0,0 +1,110 @@ +import sys, time + +# Global variables to track download progress +_start_time = None +_last_time = None +_last_downloaded = 0 +_speed_history = [] +_update_interval = 0.5 # Update speed every 0.5 seconds + +def progress_hook(block_num, block_size, total_size, filename=None): + """ + Simple progress bar hook for urlretrieve + + Args: + block_num: Number of blocks downloaded so far + block_size: Size of each block in bytes + total_size: Total size of the file in bytes + filename: Name of the file being downloaded (optional) + """ + global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval + + current_time = time.time() + downloaded = block_num * block_size + + # Initialize timing on first call + if _start_time is None or block_num == 0: + _start_time = current_time + _last_time = current_time + _last_downloaded = 0 + _speed_history = [] + + # Calculate download speed only at specified intervals + speed = 0 + if current_time - _last_time >= _update_interval: + if _last_time > 0: + current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) + _speed_history.append(current_speed) + # Keep only last 5 speed measurements for smoothing + if len(_speed_history) > 5: + _speed_history.pop(0) + # Average the recent speeds for smoother display + speed = sum(_speed_history) / len(_speed_history) + + _last_time = current_time + _last_downloaded = downloaded + elif _speed_history: + # Use the last calculated average speed + speed = sum(_speed_history) / len(_speed_history) + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + file_display = filename if filename else "Unknown file" + + if total_size <= 0: + # If total size is unknown, show downloaded bytes + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(80)) + sys.stdout.flush() + return + + downloaded = block_num * block_size + percent = min(100, (downloaded / total_size) * 100) + + # Create progress bar (40 characters wide to leave room for other info) + bar_length = 40 + filled = int(bar_length * percent / 100) + bar = '█' * filled + '░' * (bar_length - filled) + + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + + # Display progress with filename first + line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(100)) + sys.stdout.flush() + + # Print newline when complete + if percent >= 100: + print() + +# Wrapper function to include filename in progress hook +def create_progress_hook(filename): + """Creates a progress hook with the filename included""" + global _start_time, _last_time, _last_downloaded, _speed_history + # Reset timing variables for new download + _start_time = None + _last_time = None + _last_downloaded = 0 + _speed_history = [] + + def hook(block_num, block_size, total_size): + return progress_hook(block_num, block_size, total_size, filename) + return hook + + diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 17b9dde..6e8a98b 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import os import os.path as osp @@ -176,8 +175,9 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = -1): - t = t[:, frame_no] if frame_no >= 0 else t +def convert_tensor_to_image(t, frame_no = 0): + if len(t.shape) == 4: + t = t[:, frame_no] return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): @@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): if rm_background: session = new_session() output_list =[] + output_mask_list =[] for i, img in enumerate(img_list): width, height = img.size + resized_mask = None if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: - if outpainting_dims is not None: - resized_image =img + if outpainting_dims is not None and background_ref_outpainted: + resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) else: @@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, - return output_list + output_mask_list.append(resized_mask) + return output_list, output_mask_list +def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): + from shared.utils.utils import save_image + inpaint_color = canvas_tf_bg / 127.5 - 1 - - -def str2bool(v): - """ - Convert a string to a boolean. - - Supported true values: 'yes', 'true', 't', 'y', '1' - Supported false values: 'no', 'false', 'f', 'n', '0' - - Args: - v (str): String to convert. - - Returns: - bool: Converted boolean value. - - Raises: - argparse.ArgumentTypeError: If the value cannot be converted to boolean. - """ - if isinstance(v, bool): - return v - v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): - return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): - return False + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size and outpainting_dims == None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + canvas = torch.zeros_like(ref_img) if return_mask else None else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') - - -import sys, time - -# Global variables to track download progress -_start_time = None -_last_time = None -_last_downloaded = 0 -_speed_history = [] -_update_interval = 0.5 # Update speed every 0.5 seconds - -def progress_hook(block_num, block_size, total_size, filename=None): - """ - Simple progress bar hook for urlretrieve - - Args: - block_num: Number of blocks downloaded so far - block_size: Size of each block in bytes - total_size: Total size of the file in bytes - filename: Name of the file being downloaded (optional) - """ - global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval - - current_time = time.time() - downloaded = block_num * block_size - - # Initialize timing on first call - if _start_time is None or block_num == 0: - _start_time = current_time - _last_time = current_time - _last_downloaded = 0 - _speed_history = [] - - # Calculate download speed only at specified intervals - speed = 0 - if current_time - _last_time >= _update_interval: - if _last_time > 0: - current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) - _speed_history.append(current_speed) - # Keep only last 5 speed measurements for smoothing - if len(_speed_history) > 5: - _speed_history.pop(0) - # Average the recent speeds for smoother display - speed = sum(_speed_history) / len(_speed_history) - - _last_time = current_time - _last_downloaded = downloaded - elif _speed_history: - # Use the last calculated average speed - speed = sum(_speed_history) / len(_speed_history) - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - file_display = filename if filename else "Unknown file" - - if total_size <= 0: - # If total size is unknown, show downloaded bytes - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(80)) - sys.stdout.flush() - return - - downloaded = block_num * block_size - percent = min(100, (downloaded / total_size) * 100) - - # Create progress bar (40 characters wide to leave room for other info) - bar_length = 40 - filled = int(bar_length * percent / 100) - bar = '█' * filled + '░' * (bar_length - filled) - - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - - # Display progress with filename first - line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(100)) - sys.stdout.flush() - - # Print newline when complete - if percent >= 100: - print() - -# Wrapper function to include filename in progress hook -def create_progress_hook(filename): - """Creates a progress hook with the filename included""" - global _start_time, _last_time, _last_downloaded, _speed_history - # Reset timing variables for new download - _start_time = None - _last_time = None - _last_downloaded = 0 - _speed_history = [] - - def hook(block_num, block_size, total_size): - return progress_hook(block_num, block_size, total_size, filename) - return hook + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + else: + canvas_height, canvas_width = image_size + if full_frame: + new_height = canvas_height + new_width = canvas_width + top = left = 0 + else: + # if fill_max and (canvas_height - new_height) < 16: + # new_height = canvas_height + # if fill_max and (canvas_width - new_width) < 16: + # new_width = canvas_width + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img + else: + canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = canvas + canvas = None + if return_mask: + if outpainting_dims != None: + canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 + else: + canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = 0 + canvas = canvas.to(device) + if return_image: + return convert_tensor_to_image(ref_img), canvas + + return ref_img.to(device), canvas + +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): + src_videos, src_masks = [], [] + inpaint_color = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): + src_video = src_mask = None + if cur_video_guide is not None: + src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w + if cur_video_mask is not None and any_mask: + src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w + if pre_video_guide is not None and not extract_guide_from_window_start: + src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) + if any_mask: + src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) + if src_video is None: + if any_guide_padding: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") + if any_mask: + src_mask = torch.zeros_like(src_video[0:1]) + elif src_video.shape[1] < current_video_length: + if any_guide_padding: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) + if cur_video_mask is not None and any_mask: + src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + if any_mask: + src_mask = src_mask[:, :new_num_frames] + + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color + src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + + src_videos.append(src_video) + src_masks.append(src_mask) + return src_videos, src_masks diff --git a/wgp.py b/wgp.py index ee6b5f6..0e6a0b2 100644 --- a/wgp.py +++ b/wgp.py @@ -394,7 +394,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: if len(frames_positions.strip()) > 0: - positions = frames_positions.split(" ") + positions = frames_positions.replace(","," ").split(" ") for pos_str in positions: if not pos_str in ["L", "l"] and len(pos_str)>0: if not is_integer(pos_str): @@ -2528,7 +2528,7 @@ def download_models(model_filename = None, model_type= None, module_type = False from urllib.request import urlretrieve - from shared.utils.utils import create_progress_hook + from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", @@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva input_mask = convert_tensor_to_image(full_frame) return input_image, input_mask +def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): + if not input_video_path or max_frames <= 0: + return None, None + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + + any_mask = input_mask_path != None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if len(video) == 0: return None + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + frame_height, frame_width, _ = video[0].shape + + num_frames = min(len(video), len(mask_video)) + if num_frames == 0: return None + video, mask_video = video[:num_frames], mask_video[:num_frames] + + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + + face_list = [] + for frame_idx in range(num_frames): + frame = video[frame_idx].cpu().numpy() + # video[frame_idx] = None + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) + # mask_video[frame_idx] = None + if (frame_width, frame_height) != mask.size: + mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8) + alpha_mask[mask > 127] = 1 + frame = frame * alpha_mask + frame = Image.fromarray(frame) + face = face_processor.process(frame, resize_to=size, face_crop_scale = 1) + face_list.append(face) + + face_processor = None + gc.collect() + torch.cuda.empty_cache() + + face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w + if pad_frames > 0: + face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2) + + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_faces_frames = [np.array(face) for face in face_list ] + save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) + return face_tensor + def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3742,7 +3796,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, box = [xmin, ymin, xmax, ymax] box = [int(x) for x in box] return box - + inpaint_color = int(inpaint_color) + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + if not input_video_path or max_frames <= 0: return None, None any_mask = input_mask_path != None @@ -3909,6 +3969,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, preproc_outside = None gc.collect() torch.cuda.empty_cache() + if pad_frames > 0: + masked_frames = masked_frames[0] * pad_frames + masked_frames + if any_mask: masked_frames = masks[0] * pad_frames + masks return torch.stack(masked_frames), torch.stack(masks) if any_mask else None @@ -4646,7 +4709,8 @@ def generate_video( current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 - + guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) + extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) i2v = test_class_i2v(model_type) diffusion_forcing = "diffusion_forcing" in model_filename t2v = base_model_type in ["t2v"] @@ -4662,6 +4726,7 @@ def generate_video( multitalk = model_def.get("multitalk_class", False) standin = model_def.get("standin_class", False) infinitetalk = base_model_type in ["infinitetalk"] + animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4822,7 +4887,6 @@ def generate_video( repeat_no = 0 extra_generation = 0 initial_total_windows = 0 - discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length if sliding_window: @@ -4843,7 +4907,7 @@ def generate_video( if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None + src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4899,7 +4963,7 @@ def generate_video( return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images = image_refs + src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -4943,16 +5007,52 @@ def generate_video( from models.wan.multitalk.multitalk import get_window_audio_embeddings # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) - if vace: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None + + if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.replace(","," ").split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + if video_guide is not None: - keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) + keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] + guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] + keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] + guide_frames_extract_count = len(keep_frames_parsed) - if vace: + if "B" in video_prompt_type: + send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) + src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) + if src_faces is not None and src_faces.shape[1] < current_video_length: + src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) + + if vace or animate: + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None context_scale = [ control_net_weight] if "V" in video_prompt_type: process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) @@ -4971,10 +5071,10 @@ def generate_video( if preprocess_type2 is not None: context_scale = [ control_net_weight /2, control_net_weight2 /2] send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127 - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) + inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) if video_guide_processed != None: if sample_fit_canvas != None: @@ -4985,7 +5085,37 @@ def generate_video( refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] if video_mask_processed != None: refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - elif ltxv: + + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + + if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): + any_mask = True + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], + [video_mask_processed, video_mask_processed2], + pre_video_guide, image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) + else: + src_faces = src_faces[:, :src_video.shape[1]] + if args.save_masks: + save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: + save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + + elif ltxv: preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) @@ -5023,7 +5153,7 @@ def generate_video( sample_fit_canvas = None else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps) + video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) if video_guide_processed is None: src_video = pre_video_guide else: @@ -5043,29 +5173,6 @@ def generate_video( refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: - if repeat_no == 1: - frames_positions_list = [] - if frames_positions is not None and len(frames_positions)> 0: - positions = frames_positions.split(" ") - cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0 - last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count - joker_used = False - project_window_no = 1 - for pos in positions : - if len(pos) > 0: - if pos in ["L", "l"]: - cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: - joker_used = True - cur_end_pos = last_frame_no -1 - project_window_no += 1 - frames_positions_list.append(cur_end_pos) - cur_end_pos -= sliding_window_discard_last_frames + reuse_frames - else: - frames_positions_list.append(int(pos)-1 + alignment_shift) - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) - if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : from shared.utils.utils import get_outpainting_full_area_dimensions w, h = image_refs[0].size @@ -5089,20 +5196,16 @@ def generate_video( if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, block_size=block_size, - outpainting_dims =outpainting_dims ) + outpainting_dims =outpainting_dims, + background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) refresh_preview["image_refs"] = image_refs - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] if vace : - frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame] image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], @@ -5116,7 +5219,7 @@ def generate_video( any_background_ref = any_background_ref ) if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] if any_background_ref: new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] else: @@ -5165,10 +5268,14 @@ def generate_video( input_prompt = prompt, image_start = image_start_tensor, image_end = image_end_tensor, - input_frames = src_video, + input_frames = src_video, + input_frames2 = src_video2, input_ref_images= src_ref_images, + input_ref_masks = src_ref_masks, input_masks = src_mask, + input_masks2 = src_mask2, input_video= pre_video_guide, + input_faces = src_faces, denoising_strength=denoising_strength, prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, frame_num= (current_video_length // latent_size)* latent_size + 1, @@ -5302,6 +5409,7 @@ def generate_video( send_cmd("output") else: sample = samples.cpu() + abort = not is_image and sample.shape[1] < current_video_length # if True: # for testing # torch.save(sample, "output.pt") # else: @@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): old_video_prompt_type = video_prompt_type - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV") + video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type model_type = state["model_type"] @@ -7437,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) image_prompt_type_choices = [] if "T" in image_prompt_types_allowed: - image_prompt_type_choices += [("Text Prompt Only", "")] + image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")] if "S" in image_prompt_types_allowed: image_prompt_type_choices += [("Start Video with Image", "S")] any_start_image = True @@ -7516,7 +7624,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( guide_preprocessing_choices, - value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ), + value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ), label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, ) any_control_video = True @@ -7560,13 +7668,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non } mask_preprocessing_choices = [] - mask_preprocessing_labels = guide_preprocessing.get("labels", {}) + mask_preprocessing_labels = mask_preprocessing.get("labels", {}) for process_type in mask_preprocessing["selection"]: process_label = mask_preprocessing_labels.get(process_type, None) process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label mask_preprocessing_choices.append( (process_label, process_type) ) - video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed") + video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( mask_preprocessing_choices, value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), @@ -7591,7 +7699,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non choices= image_ref_choices["choices"], value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), visible = image_ref_choices.get("visible", True), - label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2 + label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2 ) image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) @@ -7634,7 +7742,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7649,14 +7757,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) - image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "") image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None - background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects") + background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects") remove_background_images_ref = gr.Dropdown( choices=[ @@ -7664,7 +7772,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal + label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal ) any_audio_voices_support = any_audio_track(base_model_type) @@ -8084,7 +8192,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue", + label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk ) From e28c95ae912addee1abdb1b798bd0f554cd658cb Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 23 Sep 2025 23:04:44 +0200 Subject: [PATCH 4/6] Vace Contenders are in Town --- README.md | 8 + defaults/lucy_edit.json | 3 +- defaults/lucy_edit_fastwan.json | 16 ++ models/flux/flux_handler.py | 2 +- models/flux/flux_main.py | 14 +- models/hyvideo/hunyuan.py | 14 +- models/ltx_video/ltxv_handler.py | 2 + models/qwen/qwen_handler.py | 2 +- models/qwen/qwen_main.py | 12 +- models/wan/animate/animate_utils.py | 143 ++++++++++ models/wan/animate/face_blocks.py | 382 +++++++++++++++++++++++++ models/wan/animate/model_animate.py | 31 +++ models/wan/animate/motion_encoder.py | 308 +++++++++++++++++++++ models/wan/any2video.py | 153 +++------- models/wan/df_handler.py | 1 + models/wan/wan_handler.py | 8 +- shared/utils/utils.py | 98 ++++--- wgp.py | 400 ++++++++++----------------- 18 files changed, 1155 insertions(+), 442 deletions(-) create mode 100644 defaults/lucy_edit_fastwan.json create mode 100644 models/wan/animate/animate_utils.py create mode 100644 models/wan/animate/face_blocks.py create mode 100644 models/wan/animate/model_animate.py create mode 100644 models/wan/animate/motion_encoder.py diff --git a/README.md b/README.md index 8c5e436..6c338ac 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena ! + +So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: +- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. + +- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. + + ### September 15 2025: WanGP v8.6 - Attack of the Clones - The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**) diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json index 6344dff..a8f67ad 100644 --- a/defaults/lucy_edit.json +++ b/defaults/lucy_edit.json @@ -2,7 +2,7 @@ "model": { "name": "Wan2.2 Lucy Edit 5B", "architecture": "lucy_edit", - "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", @@ -10,6 +10,7 @@ ], "group": "wan2_2" }, + "prompt": "change the clothes to red", "video_length": 81, "guidance_scale": 5, "flow_shift": 5, diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json new file mode 100644 index 0000000..c67c795 --- /dev/null +++ b/defaults/lucy_edit_fastwan.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Wan2.2 FastWan Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", + "URLs": "lucy_edit", + "group": "wan2_2", + "loras": "ti2v_2_2" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 5, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index d168881..83de7c3 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -56,7 +56,7 @@ class family_handler(): } - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 return extra_model_def diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 6863711..6746a23 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -142,8 +142,8 @@ class model_factory: n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -197,10 +197,12 @@ class model_factory: for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - elif image_guide is not None: - input_ref_images = [image_guide] + elif input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] else: input_ref_images = None + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( @@ -253,8 +255,8 @@ class model_factory: if image_mask is not None: from shared.utils.utils import convert_image_to_tensor img_msk_rebuilt = inp["img_msk_rebuilt"] - img= convert_image_to_tensor(image_guide) - x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) + x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt x = x.clamp(-1, 1) x = x.transpose(0, 1) diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index 181a9a7..aa6c3b3 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference): freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: if input_frames != None: - target_height, target_width = input_frames.shape[-3:-1] + target_height, target_width = input_frames.shape[-2:] elif input_video != None: target_height, target_width = input_video.shape[-2:] @@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference): pixel_value_bg = input_video.unsqueeze(0) pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0) if input_frames != None: - pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float() + # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() + pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0) if input_video != None: pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2) pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2) @@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference): if pixel_value_bg.shape[2] < frame_num: padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:]) pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2) - pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample() - pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.) + pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1 mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample() bg_latents = torch.cat([bg_latents, mask_latents], dim=1) bg_latents.mul_(self.vae.config.scaling_factor) diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index 2845fdb..8c322e1 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -35,6 +35,8 @@ class family_handler(): "selection": ["", "A", "NA", "XA", "XNA"], } + extra_model_def["extra_control_frames"] = 1 + extra_model_def["dont_cat_preguide"]= True return extra_model_def @staticmethod diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 010298e..cc6a764 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -17,7 +17,7 @@ class family_handler(): ("Default", "default"), ("Lightning", "lightning")], "guidance_max_phases" : 1, - "lock_image_refs_ratios": True, + "fit_into_canvas_image_refs": 0, } if base_model_type in ["qwen_image_edit_20B"]: diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index cec0b8e..abd5c5c 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from .pipeline_qwenimage import QwenImagePipeline from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -103,8 +103,8 @@ class model_factory(): n_prompt = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, guide_scale: float = 4, @@ -179,8 +179,10 @@ class model_factory(): if n_prompt is None or len(n_prompt) == 0: n_prompt= "text, watermark, copyright, blurry, low resolution" - if image_guide is not None: - input_ref_images = [image_guide] + + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] elif input_ref_images is not None: # image stiching method stiched = input_ref_images[0] diff --git a/models/wan/animate/animate_utils.py b/models/wan/animate/animate_utils.py new file mode 100644 index 0000000..9474dce --- /dev/null +++ b/models/wan/animate/animate_utils.py @@ -0,0 +1,143 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import numbers +from peft import LoraConfig + + +def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"): + target_modules = [] + for name, module in transformer.named_modules(): + if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear): + target_modules.append(name) + + transformer_lora_config = LoraConfig( + r=rank, + lora_alpha=alpha, + init_lora_weights=init_lora_weights, + target_modules=target_modules, + ) + return transformer_lora_config + + + +class TensorList(object): + + def __init__(self, tensors): + """ + tensors: a list of torch.Tensor objects. No need to have uniform shape. + """ + assert isinstance(tensors, (list, tuple)) + assert all(isinstance(u, torch.Tensor) for u in tensors) + assert len(set([u.ndim for u in tensors])) == 1 + assert len(set([u.dtype for u in tensors])) == 1 + assert len(set([u.device for u in tensors])) == 1 + self.tensors = tensors + + def to(self, *args, **kwargs): + return TensorList([u.to(*args, **kwargs) for u in self.tensors]) + + def size(self, dim): + assert dim == 0, 'only support get the 0th size' + return len(self.tensors) + + def pow(self, *args, **kwargs): + return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) + + def squeeze(self, dim): + assert dim != 0 + if dim > 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + + def __add__(self, other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u ** v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v ** u) + + def __neg__(self): + return TensorList([-u for u in self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return 'TensorList: \n' + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and ( + other.numel() > 1 or other.ndim > 1 + ) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and ( + other.numel() == 1 and other.ndim <= 1 + ) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) \ No newline at end of file diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py new file mode 100644 index 0000000..8ddb829 --- /dev/null +++ b/models/wan/animate/face_blocks.py @@ -0,0 +1,382 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from shared.attention import pay_attention + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. + # Size([batches, tokens, heads, head_features]) + qkv_list = [q, k, v] + del q,k,v + attn = pay_attention(qkv_list) + # attn = attention( + # q, + # k, + # v, + # max_seqlen_q=q.shape[1], + # batch_size=q.shape[0], + # ) + + attn = attn.reshape(*attn.shape[:2], -1) + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + # if use_context_parallel: + # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py new file mode 100644 index 0000000..d07f762 --- /dev/null +++ b/models/wan/animate/model_animate.py @@ -0,0 +1,31 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py new file mode 100644 index 0000000..02b0040 --- /dev/null +++ b/models/wan/animate/motion_encoder.py @@ -0,0 +1,308 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype in [torch.bfloat16, torch.float16]: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion_feat = self.enc.enc_motion(img) + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index c03752b..41d6d63 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -203,10 +203,7 @@ class WanAny2V: self.use_timestep_transform = True def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) + ref_images = [ref_images] * len(frames) if masks is None: latents = self.vae.encode(frames, tile_size = tile_size) @@ -238,11 +235,7 @@ class WanAny2V: return cat_latents def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - + ref_images = [ref_images] * len(masks) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape @@ -270,79 +263,6 @@ class WanAny2V: result_masks.append(mask) return result_masks - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): - image_sizes = [] - trim_video_guide = len(keep_video_guide_frames) - def conv_tensor(t, device): - return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames - if sub_src_mask is not None and sub_src_video is not None: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device) - # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[i][:, pos:pos+1] = 0 - src_mask[i][:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) - - - self.background_mask = None - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None and not torch.is_tensor(ref_img): - if j==0 and any_background_ref: - if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) - src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True) - else: - src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device) - if self.background_mask != None: - self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref - return src_video, src_mask, src_ref_images def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] @@ -369,7 +289,9 @@ class WanAny2V: def generate(self, input_prompt, input_frames= None, + input_frames2= None, input_masks = None, + input_masks2 = None, input_ref_images = None, input_ref_masks = None, input_faces = None, @@ -615,21 +537,22 @@ class WanAny2V: pose_pixels = input_frames * input_masks input_masks = 1. - input_masks pose_pixels -= input_masks - save_video(pose_pixels, "pose.mp4") pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) input_frames = input_frames * input_masks if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames if prefix_frames_count > 0: input_frames[:, :prefix_frames_count] = input_video input_masks[:, :prefix_frames_count] = 1 - save_video(input_frames, "input_frames.mp4") - save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + # save_video(pose_pixels, "pose.mp4") + # save_video(input_frames, "input_frames.mp4") + # save_video(input_masks, "input_masks.mp4", value_range=(0,1)) lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) msk = torch.concat([msk_ref, msk_control], dim=1) - clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) - lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + image_ref = input_ref_images[0].to(self.device) + clip_image_start = image_ref.squeeze(1) + lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) y = torch.concat([msk, lat_y]) kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) lat_y = msk = msk_control = msk_ref = pose_pixels = None @@ -701,12 +624,11 @@ class WanAny2V: # Phantom if phantom: - input_ref_images_neg = None - if input_ref_images != None: # Phantom Ref images - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 - trim_frames = input_ref_images.shape[1] + lat_input_ref_images_neg = None + if input_ref_images is not None: # Phantom Ref images + lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device) + lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images) + ref_images_count = trim_frames = lat_input_ref_images.shape[1] if ti2v: if input_video is None: @@ -721,25 +643,23 @@ class WanAny2V: # Vace if vace : # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] + input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] + input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] ref_images_before = True - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - color_reference_frame = input_ref_images[0][0].clone() - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) + if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None: + color_reference_frame = input_ref_images[0].clone() + zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): zz0[:, 0:1] = zzbg mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + zz0 = mm0 = zzbg = mmbg = None + z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)] + ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0 context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : @@ -747,15 +667,8 @@ class WanAny2V: extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) if prefix_frames_count > 0: color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - lat_h, lat_w = target_shape[-2:] - height = self.vae_stride[1] * lat_h - width = self.vae_stride[2] * lat_w - - else: - target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w) if multitalk: if audio_proj is None: @@ -860,7 +773,9 @@ class WanAny2V: apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None + gc.collect() + torch.cuda.empty_cache() # denoising trans = self.model @@ -878,7 +793,7 @@ class WanAny2V: kwargs.update({"t": timestep, "current_step": start_step_no + i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: + if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: @@ -912,8 +827,8 @@ class WanAny2V: any_guidance = guide_scale != 1 if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + - [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index 7a5d3ea..31d5da6 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -21,6 +21,7 @@ class family_handler(): extra_model_def["fps"] =fps extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 20 + extra_model_def["latent_size"] = 4 extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index ff1570f..f2a8fb4 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -114,7 +114,6 @@ class family_handler(): "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], - "convert_image_guide_to_video" : True, "sample_solvers":[ ("unipc", "unipc"), ("euler", "euler"), @@ -175,6 +174,8 @@ class family_handler(): extra_model_def["forced_guide_mask_inputs"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" extra_model_def["background_ref_outpainted"] = False + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["guide_inpaint_color"] = 0 @@ -196,15 +197,15 @@ class family_handler(): "letters_filter": "KFI", } - extra_model_def["lock_image_refs_ratios"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["pad_guide_video"] = True extra_model_def["guide_inpaint_color"] = 127.5 extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["return_image_refs_tensor"] = True if base_model_type in ["standin"]: - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 extra_model_def["image_ref_choices"] = { "choices": [ ("No Reference Image", ""), @@ -480,6 +481,7 @@ class family_handler(): ui_defaults.update({ "video_prompt_type": "PVBXAKI", "mask_expand": 20, + "audio_prompt_type_value": "R", }) if text_oneframe_overlap(base_model_type): diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 6e8a98b..bb2d5ff 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -32,6 +32,14 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -94,7 +102,7 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor: """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" cap = cv2.VideoCapture(file_name) @@ -102,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool raise ValueError(f"Cannot open video: {file_name}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + fps = round(cap.get(cv2.CAP_PROP_FPS)) + if target_fps is not None: + frame_no = round(target_fps * frame_no /fps) + # Handle out of bounds if frame_no >= total_frames or frame_no < 0: if return_last_if_missing: @@ -175,10 +186,15 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = 0): +def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): if len(t.shape) == 4: t = t[:, frame_no] - return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) + if t.shape[0]== 1: + t = t.expand(3,-1,-1) + if mask_levels: + return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + else: + return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): convert_tensor_to_image(tensor_image, frame_no).save(name) @@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): if rm_background: session = new_session() @@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg for i, img in enumerate(img_list): width, height = img.size resized_mask = None - if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: + if any_background_ref == 1 and i==0 or any_background_ref == 2: if outpainting_dims is not None and background_ref_outpainted: resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): @@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, + if return_tensor: + output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) + else: + output_list.append(resized_image) output_mask_list.append(resized_mask) return output_list, output_mask_list @@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu return ref_img.to(device), canvas -def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): src_videos, src_masks = [], [] - inpaint_color = guide_inpaint_color/127.5 - 1 - prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + inpaint_color_compressed = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): - src_video = src_mask = None - if cur_video_guide is not None: - src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w - if cur_video_mask is not None and any_mask: - src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w - if pre_video_guide is not None and not extract_guide_from_window_start: + src_video, src_mask = cur_video_guide, cur_video_mask + if pre_video_guide is not None: src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) if any_mask: src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) - if src_video is None: - if any_guide_padding: - src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") - if any_mask: - src_mask = torch.zeros_like(src_video[0:1]) - elif src_video.shape[1] < current_video_length: - if any_guide_padding: - src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) - if cur_video_mask is not None and any_mask: - src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + + if any_guide_padding: + if src_video is None: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device) + elif src_video.shape[1] < current_video_length: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) + elif src_video is not None: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + + if any_mask and src_video is not None: + if src_mask is None: + src_mask = torch.ones_like(src_video[:1]) + elif src_mask.shape[1] < src_video.shape[1]: + src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) else: - new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 - src_video = src_video[:, :new_num_frames] - if any_mask: - src_mask = src_mask[:, :new_num_frames] + src_mask = src_mask[:, :src_video.shape[1]] - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[:, pos:pos+1] = inpaint_color - src_mask[:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + if src_video is not None : + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color_compressed + if any_mask: src_mask[:, pos:pos+1] = 1 + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) + if any_mask: src_mask[:, pos:pos+1] = msk src_videos.append(src_video) src_masks.append(src_mask) return src_videos, src_masks diff --git a/wgp.py b/wgp.py index 0e6a0b2..6a7bca1 100644 --- a/wgp.py +++ b/wgp.py @@ -24,6 +24,7 @@ from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +from shared.utils.utils import has_video_file_extension, has_image_file_extension from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.61" +WanGP_version = "8.7" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type): mode_def = get_model_def(model_type) frames_minimum = mode_def.get("frames_minimum", 5) frames_steps = mode_def.get("frames_steps", 4) - return frames_minimum, frames_steps + latent_size = mode_def.get("latent_size", frames_steps) + return frames_minimum, frames_steps, latent_size def get_model_fps(model_type): mode_def = get_model_def(model_type) @@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(video_other_prompts) >0 : values += [video_other_prompts] labels += ["Other Prompts"] - if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"): + if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] video_sample_solver = configs.get("sample_solver", "") @@ -3532,6 +3534,11 @@ def convert_image(image): return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'): + if isinstance(video_in, str) and has_image_file_extension(video_in): + video_in = Image.open(video_in) + if isinstance(video_in, Image.Image): + return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0) + from shared.utils.utils import resample import decord @@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color): def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : if not items: return [] - max_workers = 11 + import concurrent.futures start_time = time.time() # print(f"Preprocessus:{process_type} started") if process_type in ["prephase", "upsample"]: if wrap_in_list : items = [ [img] for img in items] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} - results = [None] * len(items) - for future in concurrent.futures.as_completed(futures): - idx = futures[future] - results[idx] = future.result() + if max_workers == 1: + results = [image_processor(img) for img in items] + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() if wrap_in_list: results = [ img[0] for img in results] @@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis return results -def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127): - frame_width, frame_height = input_image.size - - if fit_crop: - input_image = rescale_and_crop(input_image, width, height) - if input_mask is not None: - input_mask = rescale_and_crop(input_mask, width, height) - return input_image, input_mask - - if outpainting_dims != None: - if fit_canvas != None: - frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims) - else: - frame_height, frame_width = height, width - - if fit_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size) - - if outpainting_dims != None: - final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) - - if fit_canvas != None or outpainting_dims != None: - input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS) - if input_mask is not None: - input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS) - - if expand_scale != 0 and input_mask is not None: - kernel_size = abs(expand_scale) - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - op_expand = cv2.dilate if expand_scale > 0 else cv2.erode - input_mask = np.array(input_mask) - input_mask = op_expand(input_mask, kernel, iterations=3) - input_mask = Image.fromarray(input_mask) - - if outpainting_dims != None: - inpaint_color = inpaint_color / 127.5-1 - image = convert_image_to_tensor(input_image) - full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image - input_image = convert_tensor_to_image(full_frame) - - if input_mask is not None: - mask = convert_image_to_tensor(input_mask) - full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask - input_mask = convert_tensor_to_image(full_frame) - - return input_image, input_mask def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): if not input_video_path or max_frames <= 0: return None, None @@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) return face_tensor +def get_default_workers(): + return os.cpu_count()/ 2 def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, return (target_frame, frame, mask) else: return (target_frame, None, None) - - proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + max_workers = get_default_workers() + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers) proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) for frame_idx, frame_group in enumerate(proc_lists): proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group @@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, mask_video = None if preproc2 != None: - proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers) #### to be finished ...or not - proc_list = process_images_multithread(preproc, proc_list, process_type) + proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers) if any_mask: - proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers) else: proc_list_outside = proc_mask = len(proc_list) * [None] @@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device) full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame - masks.append(mask) + masks.append(mask[:, :, 0:1].clone()) else: masked_frame = processed_img @@ -3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None - if args.save_masks: - from preprocessing.dwpose.pose import save_one_video - saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) - if any_mask: - saved_masks = [mask.cpu().numpy() for mask in masks ] - save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + # if args.save_masks: + # from preprocessing.dwpose.pose import save_one_video + # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + # if any_mask: + # saved_masks = [mask.cpu().numpy() for mask in masks ] + # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None preproc_outside = None gc.collect() @@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, if pad_frames > 0: masked_frames = masked_frames[0] * pad_frames + masked_frames if any_mask: masked_frames = masks[0] * pad_frames + masks + masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.) + masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None - return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + return masked_frames, masks def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): @@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] def upsample_frames(frame): return resize_lanczos(frame, h, w).unsqueeze(1) - sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1) frames_to_upsample = None return sample @@ -4609,17 +4574,13 @@ def generate_video( batch_size = 1 temp_filenames_list = [] - convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False) - if convert_image_guide_to_video: - if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + if image_guide is not None and isinstance(image_guide, Image.Image): + video_guide = image_guide + image_guide = None - if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + if image_mask is not None and isinstance(image_mask, Image.Image): + video_mask = image_mask + image_mask = None if model_def.get("no_background_removal", False): remove_background_images_ref = 0 @@ -4711,22 +4672,12 @@ def generate_video( device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) - i2v = test_class_i2v(model_type) - diffusion_forcing = "diffusion_forcing" in model_filename - t2v = base_model_type in ["t2v"] - ltxv = "ltxv" in model_filename - vace = test_vace_module(base_model_type) - hunyuan_t2v = "hunyuan_video_720" in model_filename - hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) - infinitetalk = base_model_type in ["infinitetalk"] - animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4763,9 +4714,9 @@ def generate_video( sliding_window_size = current_video_length reuse_frames = 0 - _, latent_size = get_model_min_frames_and_step(model_type) - if diffusion_forcing: latent_size = 4 + _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs + image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 # Output Video Ratio Priorities: @@ -4889,6 +4840,7 @@ def generate_video( initial_total_windows = 0 discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length + nb_frames_positions = 0 if sliding_window: initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size @@ -4907,7 +4859,7 @@ def generate_video( if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4963,7 +4915,6 @@ def generate_video( return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -5020,7 +4971,7 @@ def generate_video( if len(pos) > 0: if pos in ["L", "l"]: cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: + if cur_end_pos >= last_frame_no-1 and not joker_used: joker_used = True cur_end_pos = last_frame_no -1 project_window_no += 1 @@ -5036,141 +4987,53 @@ def generate_video( frames_to_inject[pos] = image_refs[i] + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None if video_guide is not None: keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + extra_control_frames = model_def.get("extra_control_frames", 0) + if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] guide_frames_extract_count = len(keep_frames_parsed) + # Extract Faces to video if "B" in video_prompt_type: send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) if src_faces is not None and src_faces.shape[1] < current_video_length: src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) - if vace or animate: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - context_scale = [ control_net_weight] - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") - else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + # Sparse Video to Video + sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - - frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] - - if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): - any_mask = True - any_guide_padding = model_def.get("pad_guide_video", False) - from shared.utils.utils import prepare_video_guide_and_mask - src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], - [video_mask_processed, video_mask_processed2], - pre_video_guide, image_size, current_video_length, latent_size, - any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, - keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) - - src_video, src_video2 = src_videos - src_mask, src_mask2 = src_masks - if src_video is None: - abort = True - break - if src_faces is not None: - if src_faces.shape[1] < src_video.shape[1]: - src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) - else: - src_faces = src_faces[:, :src_video.shape[1]] - if args.save_masks: - save_video( src_video, "masked_frames.mp4", fps) - if src_video2 is not None: - save_video( src_video2, "masked_frames2.mp4", fps) - if any_mask: - save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) - - elif ltxv: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") - status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - refresh_preview["video_mask"] = None - src_video = src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] + # Generic Video Preprocessing + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] + preprocess_type2 = process_map_video_guide.get(process_letter, None) + status_info = "Extracting " + processes_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] + if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) + inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size ) - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) - - elif "R" in video_prompt_type: # sparse video to video - src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) - src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) - if video_guide_processed is None: - src_video = pre_video_guide - else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) - elif image_guide is not None: - new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims) - if sample_fit_canvas is not None: - image_size = (new_image_guide.size[1], new_image_guide.size[0]) + if video_guide_processed is not None and sample_fit_canvas is not None: + image_size = video_guide_processed.shape[-2:] sample_fit_canvas = None - refresh_preview["image_guide"] = new_image_guide - if new_image_mask is not None: - refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : @@ -5192,45 +5055,68 @@ def generate_video( image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) refresh_preview["image_refs"] = image_refs - if len(image_refs) > nb_frames_positions: + if len(image_refs) > nb_frames_positions: + src_ref_images = image_refs[nb_frames_positions:] if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + + src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, - fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, + fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1), block_size=block_size, outpainting_dims =outpainting_dims, - background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) - refresh_preview["image_refs"] = image_refs + background_ref_outpainted = model_def.get("background_ref_outpainted", True), + return_tensor= model_def.get("return_image_refs_tensor", False) ) + - - if vace : - image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], - [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], - [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], - current_video_length, image_size = image_size, device ="cpu", - keep_video_guide_frames=keep_frames_parsed, - pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - inject_frames= frames_to_inject_parsed, - outpainting_dims = outpainting_dims, - any_background_ref = any_background_ref - ) - if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] - if any_background_ref: - new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): + any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False) + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), + [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]), + None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, + image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None + if len(src_videos) == 1: + src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None + else: + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + src_videos = src_masks = None + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) else: - new_image_refs += image_refs[nb_frames_positions:] - refresh_preview["image_refs"] = new_image_refs - new_image_refs = None - - if sample_fit_canvas != None: - image_size = src_video[0].shape[-2:] - sample_fit_canvas = None + src_faces = src_faces[:, :src_video.shape[1]] + if video_guide is not None or len(frames_to_inject_parsed) > 0: + if args.save_masks: + if src_video is not None: save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if video_guide is not None: + preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) + if src_video2 is not None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] + if src_mask is not None and video_mask is not None: + refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True) + if src_ref_images is not None or nb_frames_positions: + if len(frames_to_inject_parsed): + new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + else: + new_image_refs = [] + if src_ref_images is not None: + new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ] + refresh_preview["image_refs"] = new_image_refs + new_image_refs = None if len(refresh_preview) > 0: new_inputs= locals() @@ -5339,8 +5225,6 @@ def generate_video( pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, - image_guide= new_image_guide, - image_mask= new_image_mask, outpainting_dims = outpainting_dims, ) except Exception as e: @@ -6320,8 +6204,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["image_refs_relative_size"] if not vace: - pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"] + pop += ["frames_positions", "control_net_weight", "control_net_weight2"] + if model_def.get("video_guide_outpainting", None) is None: + pop += ["video_guide_outpainting"] + if not (vace or t2v): pop += ["min_frames_if_references"] @@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) -def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".mp4"] - -def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: - min_frames, frames_step = get_model_min_frames_and_step(base_model_type) + min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type) current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) @@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") - with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: + with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row: gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ @@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") - video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source) + video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm") with gr.Row(**default_visibility) as image_buttons_row: video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image) - video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False) video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): From 356e10ce71dc969d11eec42f05a0dfdfada20ad4 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 23 Sep 2025 23:28:43 +0200 Subject: [PATCH 5/6] fixed default sound --- models/wan/wan_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index f2a8fb4..12ddfed 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -481,7 +481,7 @@ class family_handler(): ui_defaults.update({ "video_prompt_type": "PVBXAKI", "mask_expand": 20, - "audio_prompt_type_value": "R", + "audio_prompt_type": "R", }) if text_oneframe_overlap(base_model_type): From 625b50aefd6cb48f09da5bb102fbe900acb9be84 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 00:12:22 +0200 Subject: [PATCH 6/6] fixed lucy edit fast wan, lora missing --- README.md | 2 ++ defaults/lucy_edit_fastwan.json | 2 +- wgp.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6c338ac..8cd8e9c 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. +In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the Video Mask of the person from which you want to extract the motion. + - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json index c67c795..d5d47c8 100644 --- a/defaults/lucy_edit_fastwan.json +++ b/defaults/lucy_edit_fastwan.json @@ -5,7 +5,7 @@ "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", "URLs": "lucy_edit", "group": "wan2_2", - "loras": "ti2v_2_2" + "loras": "ti2v_2_2_fastwan" }, "prompt": "change the clothes to red", "video_length": 81, diff --git a/wgp.py b/wgp.py index 6a7bca1..43f5e80 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.7" +WanGP_version = "8.71" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None