diff --git a/README.md b/README.md
index 8c5e436..6c338ac 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 ## 🔥 Latest Updates :
+### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !
+
+So in today's release you will find two Vace wannabes that each cover only a subset of Vace features but offer some interesting advantages:
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be, under the hood, a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
+
+- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in changing clothes) and voila ! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
+
+
 ### September 15 2025: WanGP v8.6 - Attack of the Clones
 
 - The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json
index 6344dff..a8f67ad 100644
--- a/defaults/lucy_edit.json
+++ b/defaults/lucy_edit.json
@@ -2,7 +2,7 @@
     "model": {
         "name": "Wan2.2 Lucy Edit 5B",
         "architecture": "lucy_edit",
-        "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
         ],
         "group": "wan2_2"
     },
+    "prompt": "change the clothes to red",
     "video_length": 81,
     "guidance_scale": 5,
     "flow_shift": 5,
diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json
new file mode 100644
index 0000000..c67c795
--- /dev/null
+++ b/defaults/lucy_edit_fastwan.json
@@ -0,0 +1,16 @@
+{
+    "model": {
+        "name": "Wan2.2 FastWan Lucy Edit 5B",
+        "architecture": "lucy_edit",
+        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. 
It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", + "URLs": "lucy_edit", + "group": "wan2_2", + "loras": "ti2v_2_2" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 5, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index d168881..83de7c3 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -56,7 +56,7 @@ class family_handler(): } - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 return extra_model_def diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 6863711..6746a23 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -142,8 +142,8 @@ class model_factory: n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -197,10 +197,12 @@ class model_factory: for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - elif image_guide is not None: - input_ref_images = [image_guide] + elif input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] else: input_ref_images = None + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( @@ -253,8 +255,8 @@ class model_factory: if image_mask is not None: from shared.utils.utils import convert_image_to_tensor img_msk_rebuilt = inp["img_msk_rebuilt"] - img= convert_image_to_tensor(image_guide) - x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) + x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt x = x.clamp(-1, 1) x = x.transpose(0, 1) diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index 181a9a7..aa6c3b3 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference): freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: if input_frames != None: - target_height, target_width = input_frames.shape[-3:-1] + target_height, target_width = input_frames.shape[-2:] elif input_video != None: target_height, target_width = input_video.shape[-2:] @@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference): pixel_value_bg = input_video.unsqueeze(0) pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0) if input_frames != None: - pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float() + # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) 
+ # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() + pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0) if input_video != None: pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2) pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2) @@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference): if pixel_value_bg.shape[2] < frame_num: padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:]) pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2) - pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample() - pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.) + pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1 mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample() bg_latents = torch.cat([bg_latents, mask_latents], dim=1) bg_latents.mul_(self.vae.config.scaling_factor) diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index 2845fdb..8c322e1 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -35,6 +35,8 @@ class family_handler(): "selection": ["", "A", "NA", "XA", "XNA"], } + extra_model_def["extra_control_frames"] = 1 + extra_model_def["dont_cat_preguide"]= True return extra_model_def @staticmethod diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 010298e..cc6a764 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -17,7 +17,7 @@ class family_handler(): ("Default", "default"), ("Lightning", "lightning")], "guidance_max_phases" : 1, - "lock_image_refs_ratios": True, + "fit_into_canvas_image_refs": 0, } if base_model_type in ["qwen_image_edit_20B"]: diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index cec0b8e..abd5c5c 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from .pipeline_qwenimage import QwenImagePipeline from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -103,8 +103,8 @@ class model_factory(): n_prompt = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, guide_scale: float = 4, @@ -179,8 +179,10 @@ class model_factory(): if n_prompt is None or len(n_prompt) == 0: n_prompt= "text, watermark, copyright, blurry, low resolution" - if image_guide is not None: - input_ref_images = [image_guide] + + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if 
input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] elif input_ref_images is not None: # image stiching method stiched = input_ref_images[0] diff --git a/models/wan/animate/animate_utils.py b/models/wan/animate/animate_utils.py new file mode 100644 index 0000000..9474dce --- /dev/null +++ b/models/wan/animate/animate_utils.py @@ -0,0 +1,143 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import numbers +from peft import LoraConfig + + +def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"): + target_modules = [] + for name, module in transformer.named_modules(): + if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear): + target_modules.append(name) + + transformer_lora_config = LoraConfig( + r=rank, + lora_alpha=alpha, + init_lora_weights=init_lora_weights, + target_modules=target_modules, + ) + return transformer_lora_config + + + +class TensorList(object): + + def __init__(self, tensors): + """ + tensors: a list of torch.Tensor objects. No need to have uniform shape. + """ + assert isinstance(tensors, (list, tuple)) + assert all(isinstance(u, torch.Tensor) for u in tensors) + assert len(set([u.ndim for u in tensors])) == 1 + assert len(set([u.dtype for u in tensors])) == 1 + assert len(set([u.device for u in tensors])) == 1 + self.tensors = tensors + + def to(self, *args, **kwargs): + return TensorList([u.to(*args, **kwargs) for u in self.tensors]) + + def size(self, dim): + assert dim == 0, 'only support get the 0th size' + return len(self.tensors) + + def pow(self, *args, **kwargs): + return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) + + def squeeze(self, dim): + assert dim != 0 + if dim > 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + + def __add__(self, other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u ** v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v ** u) + + def __neg__(self): + return 
TensorList([-u for u in self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return 'TensorList: \n' + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and ( + other.numel() > 1 or other.ndim > 1 + ) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and ( + other.numel() == 1 and other.ndim <= 1 + ) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) \ No newline at end of file diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py new file mode 100644 index 0000000..8ddb829 --- /dev/null +++ b/models/wan/animate/face_blocks.py @@ -0,0 +1,382 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from shared.attention import pay_attention + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. 
+ + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = 
torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + 
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. + # Size([batches, tokens, heads, head_features]) + qkv_list = [q, k, v] + del q,k,v + attn = pay_attention(qkv_list) + # attn = attention( + # q, + # k, + # v, + # max_seqlen_q=q.shape[1], + # batch_size=q.shape[0], + # ) + + attn = attn.reshape(*attn.shape[:2], -1) + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + # if use_context_parallel: + # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py new file mode 100644 index 0000000..d07f762 --- /dev/null +++ b/models/wan/animate/model_animate.py @@ -0,0 +1,31 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py new file mode 100644 index 0000000..02b0040 --- /dev/null +++ b/models/wan/animate/motion_encoder.py @@ -0,0 +1,308 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype in [torch.bfloat16, torch.float16]: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + 
super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is 
None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion_feat = self.enc.enc_motion(img) + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index c03752b..41d6d63 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -203,10 +203,7 @@ class WanAny2V: self.use_timestep_transform = True def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) + ref_images = [ref_images] * len(frames) if masks is None: latents = self.vae.encode(frames, tile_size = tile_size) @@ -238,11 +235,7 @@ class WanAny2V: return cat_latents def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - + ref_images = [ref_images] * len(masks) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape @@ -270,79 +263,6 @@ class WanAny2V: result_masks.append(mask) return result_masks - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): - image_sizes = [] - trim_video_guide = len(keep_video_guide_frames) - def conv_tensor(t, device): - return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames - if sub_src_mask is not None and sub_src_video is not None: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device) - # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) 
- src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[i][:, pos:pos+1] = 0 - src_mask[i][:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) - - - self.background_mask = None - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None and not torch.is_tensor(ref_img): - if j==0 and any_background_ref: - if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) - src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True) - else: - src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device) - if self.background_mask != None: - self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref - return src_video, src_mask, src_ref_images def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] @@ -369,7 +289,9 @@ class WanAny2V: def generate(self, input_prompt, input_frames= None, + input_frames2= None, input_masks = None, + input_masks2 = None, input_ref_images = None, input_ref_masks = None, input_faces = None, @@ -615,21 +537,22 @@ class WanAny2V: pose_pixels = input_frames * input_masks input_masks = 1. 
- input_masks pose_pixels -= input_masks - save_video(pose_pixels, "pose.mp4") pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) input_frames = input_frames * input_masks if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames if prefix_frames_count > 0: input_frames[:, :prefix_frames_count] = input_video input_masks[:, :prefix_frames_count] = 1 - save_video(input_frames, "input_frames.mp4") - save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + # save_video(pose_pixels, "pose.mp4") + # save_video(input_frames, "input_frames.mp4") + # save_video(input_masks, "input_masks.mp4", value_range=(0,1)) lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) msk = torch.concat([msk_ref, msk_control], dim=1) - clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) - lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + image_ref = input_ref_images[0].to(self.device) + clip_image_start = image_ref.squeeze(1) + lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) y = torch.concat([msk, lat_y]) kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) lat_y = msk = msk_control = msk_ref = pose_pixels = None @@ -701,12 +624,11 @@ class WanAny2V: # Phantom if phantom: - input_ref_images_neg = None - if input_ref_images != None: # Phantom Ref images - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 - trim_frames = input_ref_images.shape[1] + lat_input_ref_images_neg = None + if input_ref_images is not None: # Phantom Ref images + lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device) + lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images) + ref_images_count = trim_frames = lat_input_ref_images.shape[1] if ti2v: if input_video is None: @@ -721,25 +643,23 @@ class WanAny2V: # Vace if vace : # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] + input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] + input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] ref_images_before = True - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - color_reference_frame = 
input_ref_images[0][0].clone() - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) + if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None: + color_reference_frame = input_ref_images[0].clone() + zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): zz0[:, 0:1] = zzbg mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + zz0 = mm0 = zzbg = mmbg = None + z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)] + ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0 context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : @@ -747,15 +667,8 @@ class WanAny2V: extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) if prefix_frames_count > 0: color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - lat_h, lat_w = target_shape[-2:] - height = self.vae_stride[1] * lat_h - width = self.vae_stride[2] * lat_w - - else: - target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w) if multitalk: if audio_proj is None: @@ -860,7 +773,9 @@ class WanAny2V: apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None + gc.collect() + torch.cuda.empty_cache() # denoising trans = self.model @@ -878,7 +793,7 @@ class WanAny2V: kwargs.update({"t": timestep, "current_step": start_step_no + i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: + if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: @@ -912,8 +827,8 @@ class WanAny2V: any_guidance = guide_scale != 1 if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + - [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, 
:-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index 7a5d3ea..31d5da6 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -21,6 +21,7 @@ class family_handler(): extra_model_def["fps"] =fps extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 20 + extra_model_def["latent_size"] = 4 extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index ff1570f..f2a8fb4 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -114,7 +114,6 @@ class family_handler(): "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], - "convert_image_guide_to_video" : True, "sample_solvers":[ ("unipc", "unipc"), ("euler", "euler"), @@ -175,6 +174,8 @@ class family_handler(): extra_model_def["forced_guide_mask_inputs"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" extra_model_def["background_ref_outpainted"] = False + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["guide_inpaint_color"] = 0 @@ -196,15 +197,15 @@ class family_handler(): "letters_filter": "KFI", } - extra_model_def["lock_image_refs_ratios"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["pad_guide_video"] = True extra_model_def["guide_inpaint_color"] = 127.5 extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["return_image_refs_tensor"] = True if base_model_type in ["standin"]: - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 extra_model_def["image_ref_choices"] = { "choices": [ ("No Reference Image", ""), @@ -480,6 +481,7 @@ class family_handler(): ui_defaults.update({ "video_prompt_type": "PVBXAKI", "mask_expand": 20, + "audio_prompt_type_value": "R", }) if text_oneframe_overlap(base_model_type): diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 6e8a98b..bb2d5ff 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -32,6 +32,14 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -94,7 +102,7 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor: """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" cap = 
cv2.VideoCapture(file_name) @@ -102,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool raise ValueError(f"Cannot open video: {file_name}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + fps = round(cap.get(cv2.CAP_PROP_FPS)) + if target_fps is not None: + frame_no = round(target_fps * frame_no /fps) + # Handle out of bounds if frame_no >= total_frames or frame_no < 0: if return_last_if_missing: @@ -175,10 +186,15 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = 0): +def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): if len(t.shape) == 4: t = t[:, frame_no] - return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) + if t.shape[0]== 1: + t = t.expand(3,-1,-1) + if mask_levels: + return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + else: + return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): convert_tensor_to_image(tensor_image, frame_no).save(name) @@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): if rm_background: session = new_session() @@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg for i, img in enumerate(img_list): width, height = img.size resized_mask = None - if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: + if any_background_ref == 1 and i==0 or any_background_ref == 2: if outpainting_dims is not None and background_ref_outpainted: resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): @@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, + if return_tensor: + output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) + else: + output_list.append(resized_image) output_mask_list.append(resized_mask) return 
output_list, output_mask_list @@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu return ref_img.to(device), canvas -def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): src_videos, src_masks = [], [] - inpaint_color = guide_inpaint_color/127.5 - 1 - prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + inpaint_color_compressed = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): - src_video = src_mask = None - if cur_video_guide is not None: - src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w - if cur_video_mask is not None and any_mask: - src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w - if pre_video_guide is not None and not extract_guide_from_window_start: + src_video, src_mask = cur_video_guide, cur_video_mask + if pre_video_guide is not None: src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) if any_mask: src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) - if src_video is None: - if any_guide_padding: - src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") - if any_mask: - src_mask = torch.zeros_like(src_video[0:1]) - elif src_video.shape[1] < current_video_length: - if any_guide_padding: - src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) - if cur_video_mask is not None and any_mask: - src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + + if any_guide_padding: + if src_video is None: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device) + elif src_video.shape[1] < current_video_length: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) + elif src_video is not None: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + + if any_mask and src_video is not None: + if src_mask is None: + src_mask = torch.ones_like(src_video[:1]) + elif src_mask.shape[1] < src_video.shape[1]: + src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= 
src_video.device) ], dim=1) else: - new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 - src_video = src_video[:, :new_num_frames] - if any_mask: - src_mask = src_mask[:, :new_num_frames] + src_mask = src_mask[:, :src_video.shape[1]] - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[:, pos:pos+1] = inpaint_color - src_mask[:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + if src_video is not None : + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color_compressed + if any_mask: src_mask[:, pos:pos+1] = 1 + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) + if any_mask: src_mask[:, pos:pos+1] = msk src_videos.append(src_video) src_masks.append(src_mask) return src_videos, src_masks diff --git a/wgp.py b/wgp.py index 0e6a0b2..6a7bca1 100644 --- a/wgp.py +++ b/wgp.py @@ -24,6 +24,7 @@ from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +from shared.utils.utils import has_video_file_extension, has_image_file_extension from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.61" +WanGP_version = "8.7" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type): mode_def = get_model_def(model_type) frames_minimum = mode_def.get("frames_minimum", 5) frames_steps = mode_def.get("frames_steps", 4) - return frames_minimum, frames_steps + latent_size = mode_def.get("latent_size", frames_steps) + return frames_minimum, frames_steps, latent_size def get_model_fps(model_type): mode_def = get_model_def(model_type) @@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(video_other_prompts) >0 : values += [video_other_prompts] labels += ["Other Prompts"] - if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"): + if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] video_sample_solver = configs.get("sample_solver", "") @@ -3532,6 +3534,11 @@ 
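For reference, the reworked prepare_video_guide_and_mask earlier in this patch boils down to two cases per guide: pad a missing or short guide up to the requested length with the neutral inpaint color, or trim an unpadded guide so its frame count stays latent-aligned. A minimal standalone sketch of that logic, with illustrative names and guide tensors assumed to be (c, f, h, w) floats in [-1, 1] (not the actual WanGP helper):

import torch

def pad_or_trim_guide(guide, target_frames, latent_size=4, guide_inpaint_color=127.5,
                      pad_to_target=False, image_size=(480, 832)):
    # Neutral fill value expressed in the [-1, 1] range used by the guide tensors.
    inpaint_color_compressed = guide_inpaint_color / 127.5 - 1.0
    if pad_to_target:
        if guide is None:
            # No guide at all: synthesize a clip that is entirely "to be inpainted".
            return torch.full((3, target_frames, *image_size), inpaint_color_compressed)
        if guide.shape[1] < target_frames:
            pad = torch.full((3, target_frames - guide.shape[1], *guide.shape[-2:]),
                             inpaint_color_compressed, dtype=guide.dtype, device=guide.device)
            return torch.cat([guide, pad], dim=1)
        return guide
    if guide is None:
        return None
    # Without padding, keep only 1 + k * latent_size frames so latents line up.
    aligned_frames = (guide.shape[1] - 1) // latent_size * latent_size + 1
    return guide[:, :aligned_frames]

# Example: a 30-frame guide with latent_size=4 is trimmed to 29 frames (1 + 7*4);
# with pad_to_target=True and target_frames=81 it is padded to 81 frames instead.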
def convert_image(image): return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'): + if isinstance(video_in, str) and has_image_file_extension(video_in): + video_in = Image.open(video_in) + if isinstance(video_in, Image.Image): + return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0) + from shared.utils.utils import resample import decord @@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color): def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : if not items: return [] - max_workers = 11 + import concurrent.futures start_time = time.time() # print(f"Preprocessus:{process_type} started") if process_type in ["prephase", "upsample"]: if wrap_in_list : items = [ [img] for img in items] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} - results = [None] * len(items) - for future in concurrent.futures.as_completed(futures): - idx = futures[future] - results[idx] = future.result() + if max_workers == 1: + results = [image_processor(img) for img in items] + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() if wrap_in_list: results = [ img[0] for img in results] @@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis return results -def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127): - frame_width, frame_height = input_image.size - - if fit_crop: - input_image = rescale_and_crop(input_image, width, height) - if input_mask is not None: - input_mask = rescale_and_crop(input_mask, width, height) - return input_image, input_mask - - if outpainting_dims != None: - if fit_canvas != None: - frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims) - else: - frame_height, frame_width = height, width - - if fit_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size) - - if outpainting_dims != None: - final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) - - if fit_canvas != None or outpainting_dims != None: - input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS) - if input_mask is not None: - input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS) - - if expand_scale != 0 and input_mask is not None: - kernel_size = abs(expand_scale) - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - op_expand = cv2.dilate if expand_scale > 0 else cv2.erode - input_mask = np.array(input_mask) - input_mask = op_expand(input_mask, kernel, iterations=3) - input_mask = Image.fromarray(input_mask) - - if outpainting_dims != None: - inpaint_color = inpaint_color / 127.5-1 - image = 
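process_images_multithread now takes its worker count from the caller and runs serially when max_workers is 1. A rough sketch of the pattern it implements, using index-keyed futures so results come back in input order (generic processor shown here, not the WanGP function itself):

import concurrent.futures
import os

def run_ordered(processor, items, max_workers=None):
    if not items:
        return []
    if max_workers is None:
        # Same spirit as get_default_workers(): roughly half the CPUs.
        max_workers = max(1, (os.cpu_count() or 2) // 2)
    if max_workers == 1:
        # Serial fallback: cheaper for tiny batches or memory-hungry processors.
        return [processor(item) for item in items]
    results = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Keep the submission index so completion order does not scramble frame order.
        futures = {executor.submit(processor, item): idx for idx, item in enumerate(items)}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
    return results

# frames_out = run_ordered(lambda img: img.convert("L"), pil_frames)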
convert_image_to_tensor(input_image) - full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image - input_image = convert_tensor_to_image(full_frame) - - if input_mask is not None: - mask = convert_image_to_tensor(input_mask) - full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask - input_mask = convert_tensor_to_image(full_frame) - - return input_image, input_mask def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): if not input_video_path or max_frames <= 0: return None, None @@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) return face_tensor +def get_default_workers(): + return os.cpu_count()/ 2 def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, return (target_frame, frame, mask) else: return (target_frame, None, None) - - proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + max_workers = get_default_workers() + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers) proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) for frame_idx, frame_group in enumerate(proc_lists): proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group @@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, mask_video = None if preproc2 != None: - proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers) #### to be finished ...or not - proc_list = process_images_multithread(preproc, proc_list, process_type) + proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers) if any_mask: - proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers) else: proc_list_outside = proc_mask = len(proc_list) * [None] @@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device) full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame - masks.append(mask) + masks.append(mask[:, :, 0:1].clone()) else: masked_frame = processed_img @@ 
-3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None - if args.save_masks: - from preprocessing.dwpose.pose import save_one_video - saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) - if any_mask: - saved_masks = [mask.cpu().numpy() for mask in masks ] - save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + # if args.save_masks: + # from preprocessing.dwpose.pose import save_one_video + # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + # if any_mask: + # saved_masks = [mask.cpu().numpy() for mask in masks ] + # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None preproc_outside = None gc.collect() @@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, if pad_frames > 0: masked_frames = masked_frames[0] * pad_frames + masked_frames if any_mask: masked_frames = masks[0] * pad_frames + masks + masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.) + masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None - return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + return masked_frames, masks def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): @@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] def upsample_frames(frame): return resize_lanczos(frame, h, w).unsqueeze(1) - sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1) frames_to_upsample = None return sample @@ -4609,17 +4574,13 @@ def generate_video( batch_size = 1 temp_filenames_list = [] - convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False) - if convert_image_guide_to_video: - if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + if image_guide is not None and isinstance(image_guide, Image.Image): + video_guide = image_guide + image_guide = None - if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + if image_mask is not None and isinstance(image_mask, Image.Image): + video_mask = image_mask + image_mask = None if model_def.get("no_background_removal", False): remove_background_images_ref = 0 @@ -4711,22 +4672,12 @@ def generate_video( device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) extract_guide_from_window_start = 
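preprocess_video_with_mask now returns tensors directly instead of lists of frames; the stacking shown above is essentially a layout change plus normalization. A small sketch of that step, assuming lists of HWC uint8 torch tensors (illustrative helper, not the real function):

import torch

def frames_and_masks_to_tensors(frames, masks=None):
    # (f, h, w, c) stack -> (c, f, h, w), pixels mapped from [0, 255] to [-1, 1].
    video = torch.stack(frames).permute(-1, 0, 1, 2).float().div_(127.5).sub_(1.0)
    mask_video = None
    if masks is not None:
        # Masks stay single channel and are mapped to [0, 1] instead.
        mask_video = torch.stack(masks).permute(-1, 0, 1, 2).float().div_(255)
    return video, mask_video

# frames = [torch.zeros(480, 832, 3, dtype=torch.uint8) for _ in range(16)]
# masks  = [torch.full((480, 832, 1), 255, dtype=torch.uint8) for _ in range(16)]
# video, mask_video = frames_and_masks_to_tensors(frames, masks)
# video.shape == (3, 16, 480, 832); mask_video.shape == (1, 16, 480, 832)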
model_def.get("extract_guide_from_window_start", False) - i2v = test_class_i2v(model_type) - diffusion_forcing = "diffusion_forcing" in model_filename - t2v = base_model_type in ["t2v"] - ltxv = "ltxv" in model_filename - vace = test_vace_module(base_model_type) - hunyuan_t2v = "hunyuan_video_720" in model_filename - hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) - infinitetalk = base_model_type in ["infinitetalk"] - animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4763,9 +4714,9 @@ def generate_video( sliding_window_size = current_video_length reuse_frames = 0 - _, latent_size = get_model_min_frames_and_step(model_type) - if diffusion_forcing: latent_size = 4 + _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs + image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 # Output Video Ratio Priorities: @@ -4889,6 +4840,7 @@ def generate_video( initial_total_windows = 0 discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length + nb_frames_positions = 0 if sliding_window: initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size @@ -4907,7 +4859,7 @@ def generate_video( if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4963,7 +4915,6 @@ def generate_video( return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -5020,7 +4971,7 @@ def generate_video( if len(pos) > 0: if pos in ["L", "l"]: cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: + if cur_end_pos >= last_frame_no-1 and not joker_used: joker_used = True cur_end_pos = last_frame_no -1 project_window_no += 1 @@ -5036,141 +4987,53 @@ def generate_video( frames_to_inject[pos] = image_refs[i] + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None if video_guide 
is not None: keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + extra_control_frames = model_def.get("extra_control_frames", 0) + if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] guide_frames_extract_count = len(keep_frames_parsed) + # Extract Faces to video if "B" in video_prompt_type: send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) if src_faces is not None and src_faces.shape[1] < current_video_length: src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) - if vace or animate: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - context_scale = [ control_net_weight] - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") - else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, 
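The new extra_control_frames handling above shifts the guide extraction window a few frames earlier when the model definition asks for it. A compact sketch of that computation with illustrative values (variable names mirror the ones used in wgp.py):

def guide_extraction_window(aligned_guide_start_frame, aligned_guide_end_frame,
                            keep_frames_parsed_full, extra_control_frames=0):
    start = aligned_guide_start_frame
    # Pull earlier control frames only when enough history exists before the window.
    if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames:
        start -= extra_control_frames
    # A negative start would mean "before the clip"; those slots are simply kept.
    keep_frames_parsed = [True] * -start if start < 0 else []
    keep_frames_parsed += keep_frames_parsed_full[max(0, start):aligned_guide_end_frame]
    return start, keep_frames_parsed

start, keep = guide_extraction_window(8, 24, [True] * 100, extra_control_frames=4)
# start == 4 and len(keep) == 20: four extra frames are extracted ahead of the window.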
start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + # Sparse Video to Video + sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - - frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] - - if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): - any_mask = True - any_guide_padding = model_def.get("pad_guide_video", False) - from shared.utils.utils import prepare_video_guide_and_mask - src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], - [video_mask_processed, video_mask_processed2], - pre_video_guide, image_size, current_video_length, latent_size, - any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, - keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) - - src_video, src_video2 = src_videos - src_mask, src_mask2 = src_masks - if src_video is None: - abort = True - break - if src_faces is not None: - if src_faces.shape[1] < src_video.shape[1]: - src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) - else: - src_faces = src_faces[:, :src_video.shape[1]] - if args.save_masks: - save_video( src_video, "masked_frames.mp4", fps) - if src_video2 is not None: - save_video( src_video2, "masked_frames2.mp4", fps) - if any_mask: - save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) - - elif ltxv: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") - status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - refresh_preview["video_mask"] = None - src_video = 
src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] + # Generic Video Preprocessing + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] + preprocess_type2 = process_map_video_guide.get(process_letter, None) + status_info = "Extracting " + processes_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] + if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) + inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size ) - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - 
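The generic preprocessing path above derives up to two preprocessors from the control-video letters and splits the control-net weight when both are active. A simplified sketch; the mapping dict below is hypothetical except for "P" -> "pose", which the inpaint_color check above confirms:

PROCESS_MAP = {"P": "pose", "D": "depth", "E": "edges"}  # illustrative entries only

def pick_preprocessors(video_prompt_type, control_net_weight=1.0, control_net_weight2=1.0,
                       letters="PEDSLCMU"):
    primary, secondary = "raw", None
    selected = [c for c in video_prompt_type if c in letters]  # filter_letters equivalent
    for i, letter in enumerate(selected):
        if i == 0:
            primary = PROCESS_MAP.get(letter, "raw")
        else:
            secondary = PROCESS_MAP.get(letter, None)
    # Two preprocessors share the budget: each gets half of its control-net weight.
    context_scale = ([control_net_weight / 2, control_net_weight2 / 2]
                     if secondary is not None else [control_net_weight])
    return primary, secondary, context_scale

# pick_preprocessors("PV")  -> ("pose", None, [1.0])
# pick_preprocessors("PDV") -> ("pose", "depth", [0.5, 0.5])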
refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) - - elif "R" in video_prompt_type: # sparse video to video - src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) - src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) - if video_guide_processed is None: - src_video = pre_video_guide - else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) - elif image_guide is not None: - new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims) - if sample_fit_canvas is not None: - image_size = (new_image_guide.size[1], new_image_guide.size[0]) + if video_guide_processed is not None and sample_fit_canvas is not None: + image_size = video_guide_processed.shape[-2:] sample_fit_canvas = None - refresh_preview["image_guide"] = new_image_guide - if new_image_mask is not None: - refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : @@ -5192,45 +5055,68 @@ def generate_video( image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) refresh_preview["image_refs"] = image_refs - if len(image_refs) > nb_frames_positions: + if len(image_refs) > nb_frames_positions: + src_ref_images = image_refs[nb_frames_positions:] if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + + src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, - fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, + fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1), block_size=block_size, outpainting_dims =outpainting_dims, - background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) - refresh_preview["image_refs"] = image_refs + background_ref_outpainted = model_def.get("background_ref_outpainted", True), + return_tensor= model_def.get("return_image_refs_tensor", False) ) + - - if vace : - image_refs_copy = 
image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], - [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], - [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], - current_video_length, image_size = image_size, device ="cpu", - keep_video_guide_frames=keep_frames_parsed, - pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - inject_frames= frames_to_inject_parsed, - outpainting_dims = outpainting_dims, - any_background_ref = any_background_ref - ) - if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] - if any_background_ref: - new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): + any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False) + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), + [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]), + None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, + image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None + if len(src_videos) == 1: + src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None + else: + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + src_videos = src_masks = None + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) else: - new_image_refs += image_refs[nb_frames_positions:] - refresh_preview["image_refs"] = new_image_refs - new_image_refs = None - - if sample_fit_canvas != None: - image_size = src_video[0].shape[-2:] - sample_fit_canvas = None + src_faces = src_faces[:, :src_video.shape[1]] + if video_guide is not None or len(frames_to_inject_parsed) > 0: + if args.save_masks: + if src_video is not None: save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if video_guide is 
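Since prepare_video_guide_and_mask returns one entry per control guide, the unpacking above has to cover both the single-guide and dual-guide cases. A tiny sketch of that step, assuming the parallel-list return shape shown above:

def unpack_guides(src_videos, src_masks):
    # One guide: the secondary slots stay None; two guides: unpack both pairs.
    if len(src_videos) == 1:
        return src_videos[0], None, src_masks[0], None
    src_video, src_video2 = src_videos
    src_mask, src_mask2 = src_masks
    return src_video, src_video2, src_mask, src_mask2

# unpack_guides([v], [m])            -> (v, None, m, None)
# unpack_guides([v1, v2], [m1, m2])  -> (v1, v2, m1, m2)
# A None primary src_video after preparation means there is nothing left to drive
# this window, which is why the generation loop above aborts in that case.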
not None: + preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) + if src_video2 is not None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] + if src_mask is not None and video_mask is not None: + refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True) + if src_ref_images is not None or nb_frames_positions: + if len(frames_to_inject_parsed): + new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + else: + new_image_refs = [] + if src_ref_images is not None: + new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ] + refresh_preview["image_refs"] = new_image_refs + new_image_refs = None if len(refresh_preview) > 0: new_inputs= locals() @@ -5339,8 +5225,6 @@ def generate_video( pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, - image_guide= new_image_guide, - image_mask= new_image_mask, outpainting_dims = outpainting_dims, ) except Exception as e: @@ -6320,8 +6204,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["image_refs_relative_size"] if not vace: - pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"] + pop += ["frames_positions", "control_net_weight", "control_net_weight2"] + if model_def.get("video_guide_outpainting", None) is None: + pop += ["video_guide_outpainting"] + if not (vace or t2v): pop += ["min_frames_if_references"] @@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) -def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".mp4"] - -def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: - min_frames, frames_step = get_model_min_frames_and_step(base_model_type) + min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type) current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) @@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non MMAudio_neg_prompt = 
gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") - with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: + with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row: gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ @@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") - video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source) + video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm") with gr.Row(**default_visibility) as image_buttons_row: video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image) - video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False) video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"):
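Footnote on the wgp.py hunk further above: has_video_file_extension and has_image_file_extension were not dropped, they are now imported from shared.utils.utils (see the import added at the top of wgp.py). For reference, the versions removed from wgp.py read:

import os

def has_video_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".mp4"]

def has_image_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp",
                         ".tif", ".tiff", ".jfif", ".pjpeg"]

# get_resampled_video now uses has_image_file_extension to treat a single image
# path as a one-frame video (Image.open -> a uint8 tensor with a leading frame axis).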