Vace Contenders are in Town

This commit is contained in:
DeepBeepMeep 2025-09-23 23:04:44 +02:00
parent 84010bd861
commit e28c95ae91
18 changed files with 1155 additions and 442 deletions

View File

@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !
So in today's release you will find two Vace wannabes; each covers only a subset of the Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*, and it does that very well. You can use it either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video, ask it to change it (it is specialized in changing clothes) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
### September 15 2025: WanGP v8.6 - Attack of the Clones
- The long-awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)

View File

@ -2,7 +2,7 @@
"model": {
"name": "Wan2.2 Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@ -10,6 +10,7 @@
],
"group": "wan2_2"
},
"prompt": "change the clothes to red",
"video_length": 81,
"guidance_scale": 5,
"flow_shift": 5,

View File

@ -0,0 +1,16 @@
{
"model": {
"name": "Wan2.2 FastWan Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
"URLs": "lucy_edit",
"group": "wan2_2",
"loras": "ti2v_2_2"
},
"prompt": "change the clothes to red",
"video_length": 81,
"guidance_scale": 1,
"flow_shift": 3,
"num_inference_steps": 5,
"resolution": "1280x720"
}

View File

@ -56,7 +56,7 @@ class family_handler():
}
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["fit_into_canvas_image_refs"] = 0
return extra_model_def

View File

@ -142,8 +142,8 @@ class model_factory:
n_prompt: str = None,
sampling_steps: int = 20,
input_ref_images = None,
image_guide= None,
image_mask= None,
input_frames= None,
input_masks= None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
@ -197,10 +197,12 @@ class model_factory:
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
elif image_guide is not None:
input_ref_images = [image_guide]
elif input_frames is not None:
input_ref_images = [convert_tensor_to_image(input_frames) ]
else:
input_ref_images = None
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
inp, height, width = prepare_multi_ip(
@ -253,8 +255,8 @@ class model_factory:
if image_mask is not None:
from shared.utils.utils import convert_image_to_tensor
img_msk_rebuilt = inp["img_msk_rebuilt"]
img= convert_image_to_tensor(image_guide)
x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
x = x.clamp(-1, 1)
x = x.transpose(0, 1)

View File

@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference):
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
else:
if input_frames != None:
target_height, target_width = input_frames.shape[-3:-1]
target_height, target_width = input_frames.shape[-2:]
elif input_video != None:
target_height, target_width = input_video.shape[-2:]
@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference):
pixel_value_bg = input_video.unsqueeze(0)
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
if input_frames != None:
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
# pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
# pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
if input_video != None:
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference):
if pixel_value_bg.shape[2] < frame_num:
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
# pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels are -1 (not 0 as usual) and masked pixels are 1
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
bg_latents.mul_(self.vae.config.scaling_factor)

View File

@ -35,6 +35,8 @@ class family_handler():
"selection": ["", "A", "NA", "XA", "XNA"],
}
extra_model_def["extra_control_frames"] = 1
extra_model_def["dont_cat_preguide"]= True
return extra_model_def
@staticmethod

View File

@ -17,7 +17,7 @@ class family_handler():
("Default", "default"),
("Lightning", "lightning")],
"guidance_max_phases" : 1,
"lock_image_refs_ratios": True,
"fit_into_canvas_image_refs": 0,
}
if base_model_type in ["qwen_image_edit_20B"]:

View File

@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
def stitch_images(img1, img2):
# Resize img2 to match img1's height
@ -103,8 +103,8 @@ class model_factory():
n_prompt = None,
sampling_steps: int = 20,
input_ref_images = None,
image_guide= None,
image_mask= None,
input_frames= None,
input_masks= None,
width= 832,
height=480,
guide_scale: float = 4,
@ -179,8 +179,10 @@ class model_factory():
if n_prompt is None or len(n_prompt) == 0:
n_prompt= "text, watermark, copyright, blurry, low resolution"
if image_guide is not None:
input_ref_images = [image_guide]
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if input_frames is not None:
input_ref_images = [convert_tensor_to_image(input_frames) ]
elif input_ref_images is not None:
# image stiching method
stiched = input_ref_images[0]

View File

@ -0,0 +1,143 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
target_modules = []
for name, module in transformer.named_modules():
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
target_modules.append(name)
transformer_lora_config = LoraConfig(
r=rank,
lora_alpha=alpha,
init_lora_weights=init_lora_weights,
target_modules=target_modules,
)
return transformer_lora_config
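A minimal usage sketch for the helper above, assuming peft's standard get_peft_model wrapping step (the toy module and all dimensions below are illustrative, not taken from this commit):

import torch
from torch import nn
from peft import get_peft_model

class ToyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.attn = nn.Linear(dim, dim)        # matched: lives under "blocks" and is a plain Linear
        self.modulation = nn.Linear(dim, dim)  # skipped: its name contains "modulation"

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock() for _ in range(2))

toy = ToyTransformer()
lora_config = get_loraconfig(toy, rank=8, alpha=8)   # targets blocks.0.attn and blocks.1.attn
toy = get_peft_model(toy, lora_config)               # injects the LoRA adapters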
class TensorList(object):
def __init__(self, tensors):
"""
tensors: a list of torch.Tensor objects. No need to have uniform shape.
"""
assert isinstance(tensors, (list, tuple))
assert all(isinstance(u, torch.Tensor) for u in tensors)
assert len(set([u.ndim for u in tensors])) == 1
assert len(set([u.dtype for u in tensors])) == 1
assert len(set([u.device for u in tensors])) == 1
self.tensors = tensors
def to(self, *args, **kwargs):
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
def size(self, dim):
assert dim == 0, 'only support get the 0th size'
return len(self.tensors)
def pow(self, *args, **kwargs):
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
def squeeze(self, dim):
assert dim != 0
if dim > 0:
dim -= 1
return TensorList([u.squeeze(dim) for u in self.tensors])
def type(self, *args, **kwargs):
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
def type_as(self, other):
assert isinstance(other, (torch.Tensor, TensorList))
if isinstance(other, torch.Tensor):
return TensorList([u.type_as(other) for u in self.tensors])
else:
return TensorList([u.type(other.dtype) for u in self.tensors])
@property
def dtype(self):
return self.tensors[0].dtype
@property
def device(self):
return self.tensors[0].device
@property
def ndim(self):
return 1 + self.tensors[0].ndim
def __getitem__(self, index):
return self.tensors[index]
def __len__(self):
return len(self.tensors)
def __add__(self, other):
return self._apply(other, lambda u, v: u + v)
def __radd__(self, other):
return self._apply(other, lambda u, v: v + u)
def __sub__(self, other):
return self._apply(other, lambda u, v: u - v)
def __rsub__(self, other):
return self._apply(other, lambda u, v: v - u)
def __mul__(self, other):
return self._apply(other, lambda u, v: u * v)
def __rmul__(self, other):
return self._apply(other, lambda u, v: v * u)
def __floordiv__(self, other):
return self._apply(other, lambda u, v: u // v)
def __truediv__(self, other):
return self._apply(other, lambda u, v: u / v)
def __rfloordiv__(self, other):
return self._apply(other, lambda u, v: v // u)
def __rtruediv__(self, other):
return self._apply(other, lambda u, v: v / u)
def __pow__(self, other):
return self._apply(other, lambda u, v: u ** v)
def __rpow__(self, other):
return self._apply(other, lambda u, v: v ** u)
def __neg__(self):
return TensorList([-u for u in self.tensors])
def __iter__(self):
for tensor in self.tensors:
yield tensor
def __repr__(self):
return 'TensorList: \n' + repr(self.tensors)
def _apply(self, other, op):
if isinstance(other, (list, tuple, TensorList)) or (
isinstance(other, torch.Tensor) and (
other.numel() > 1 or other.ndim > 1
)
):
assert len(other) == len(self.tensors)
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
elif isinstance(other, numbers.Number) or (
isinstance(other, torch.Tensor) and (
other.numel() == 1 and other.ndim <= 1
)
):
return TensorList([op(u, other) for u in self.tensors])
else:
raise TypeError(
f'unsupported operand for *: "TensorList" and "{type(other)}"'
)
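To make the point of TensorList concrete, here is a small sketch (the tensors and shapes are made up): arithmetic with a scalar broadcasts over every tensor in the list, while arithmetic with a matching list is applied pairwise, even when the member shapes differ.

import torch

latents = TensorList([torch.zeros(16, 8, 8), torch.zeros(16, 6, 10)])   # same ndim/dtype/device, different shapes
scaled = latents * 0.5 + 1.0                                            # scalar ops hit every member
noise = [torch.randn(16, 8, 8), torch.randn(16, 6, 10)]
noisy = latents + noise                                                 # pairwise against a list of equal length
half = noisy.to(torch.float16).type_as(latents)                         # dtype/device plumbing mirrors torch.Tensor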

View File

@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from shared.attention import pay_attention
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def attention(
q,
k,
v,
mode="torch",
drop_rate=0,
attn_mask=None,
causal=False,
max_seqlen_q=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into q.
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into kv.
max_seqlen_q (int): The maximum sequence length in the batch of q.
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
elif mode == "flash":
x = flash_attn_func(
q,
k,
v,
)
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(q.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class FaceEncoder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
self.out_proj = nn.Linear(1024, hidden_dim)
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, "b t c -> b c t")
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv2(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, "b t c -> b c t")
x = self.conv3(x)
x = rearrange(x, "b c t -> b t c")
x = self.norm3(x)
x = self.act(x)
x = self.out_proj(x)
x = rearrange(x, "(b n) t c -> b t n c", b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
return x_local
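To make the reshaping in FaceEncoder easier to follow, a hedged sketch; in_dim=512, hidden_dim=5120, num_heads=4 and the 16-frame input are illustrative guesses rather than values taken from this commit:

import torch

encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
motion_vec = torch.randn(1, 16, 512)   # (batch, frames, per-frame motion feature)
out = encoder(motion_vec)
print(out.shape)                       # torch.Size([1, 4, 5, 5120]): the two strided causal convs
                                       # downsample 16 frames to 4, and one learned padding token
                                       # is appended along the per-head axis (4 heads -> 5)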
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
def __init__(
self,
hidden_dim: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
num_adapter_layers: int = 1,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.hidden_size = hidden_dim
self.heads_num = heads_num
self.fuser_blocks = nn.ModuleList(
[
FaceBlock(
self.hidden_size,
self.heads_num,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(num_adapter_layers)
]
)
def forward(
self,
x: torch.Tensor,
motion_embed: torch.Tensor,
idx: int,
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
) -> torch.Tensor:
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
def __init__(
self,
hidden_size: int,
heads_num: int,
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.scale = qk_scale or head_dim**-0.5
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
)
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
def forward(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
) -> torch.Tensor:
B, T, N, C = motion_vec.shape
T_comp = T
x_motion = self.pre_norm_motion(motion_vec)
x_feat = self.pre_norm_feat(x)
kv = self.linear1_kv(x_motion)
q = self.linear1_q(x_feat)
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
k = rearrange(k, "B L N H D -> (B L) N H D")
v = rearrange(v, "B L N H D -> (B L) N H D")
if use_context_parallel:
q = gather_forward(q, dim=1)
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
# Compute attention.
# Size([batches, tokens, heads, head_features])
qkv_list = [q, k, v]
del q,k,v
attn = pay_attention(qkv_list)
# attn = attention(
# q,
# k,
# v,
# max_seqlen_q=q.shape[1],
# batch_size=q.shape[0],
# )
attn = attn.reshape(*attn.shape[:2], -1)
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
# if use_context_parallel:
# attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
output = self.linear2(attn)
if motion_mask is not None:
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
return output

View File

@ -0,0 +1,31 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
pose_latents = self.pose_patch_embedding(pose_latents)
x[:, :, 1:] += pose_latents
b,c,T,h,w = face_pixel_values.shape
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
encode_bs = 8
face_pixel_values_tmp = []
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
motion_vec = torch.cat(face_pixel_values_tmp)
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
motion_vec = self.face_encoder(motion_vec)
B, L, H, C = motion_vec.shape
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
return x, motion_vec

View File

@ -0,0 +1,308 @@
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
def custom_qr(input_tensor):
original_dtype = input_tensor.dtype
if original_dtype in [torch.bfloat16, torch.float16]:
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
return q.to(original_dtype), r.to(original_dtype)
return torch.linalg.qr(input_tensor)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
return F.leaky_relu(input + bias, negative_slope) * scale
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, minor, in_h, in_w = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, minor, in_h, 1, in_w, 1)
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
return out[:, :, ::down_y, ::down_x]
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer('kernel', kernel)
self.pad = pad
def forward(self, input):
return upfirdn2d(input, self.kernel, pad=self.pad)
class ScaledLeakyReLU(nn.Module):
def __init__(self, negative_slope=0.2):
super().__init__()
self.negative_slope = negative_slope
def forward(self, input):
return F.leaky_relu(input, negative_slope=self.negative_slope)
class EqualConv2d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
class ConvLayer(nn.Sequential):
def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
bias=bias and not activate))
if activate:
if bias:
layers.append(FusedLeakyReLU(out_channel))
else:
layers.append(ScaledLeakyReLU(0.2))
super().__init__(*layers)
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
skip = self.skip(input)
out = (out + skip) / math.sqrt(2)
return out
class EncoderApp(nn.Module):
def __init__(self, size, w_dim=512):
super(EncoderApp, self).__init__()
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256,
128: 128,
256: 64,
512: 32,
1024: 16
}
self.w_dim = w_dim
log_size = int(math.log(size, 2))
self.convs = nn.ModuleList()
self.convs.append(ConvLayer(3, channels[size], 1))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
self.convs.append(ResBlock(in_channel, out_channel))
in_channel = out_channel
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
def forward(self, x):
res = []
h = x
for conv in self.convs:
h = conv(h)
res.append(h)
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
class Encoder(nn.Module):
def __init__(self, size, dim=512, dim_motion=20):
super(Encoder, self).__init__()
# appearance network
self.net_app = EncoderApp(size, dim)
# motion network
fc = [EqualLinear(dim, dim)]
for i in range(3):
fc.append(EqualLinear(dim, dim))
fc.append(EqualLinear(dim, dim_motion))
self.fc = nn.Sequential(*fc)
def enc_app(self, x):
h_source = self.net_app(x)
return h_source
def enc_motion(self, x):
h, _ = self.net_app(x)
h_motion = self.fc(h)
return h_motion
class Direction(nn.Module):
def __init__(self, motion_dim):
super(Direction, self).__init__()
self.weight = nn.Parameter(torch.randn(512, motion_dim))
def forward(self, input):
weight = self.weight + 1e-8
Q, R = custom_qr(weight)
if input is None:
return Q
else:
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
out = torch.matmul(input_diag, Q.T)
out = torch.sum(out, dim=1)
return out
class Synthesis(nn.Module):
def __init__(self, motion_dim):
super(Synthesis, self).__init__()
self.direction = Direction(motion_dim)
class Generator(nn.Module):
def __init__(self, size, style_dim=512, motion_dim=20):
super().__init__()
self.enc = Encoder(size, style_dim, motion_dim)
self.dec = Synthesis(motion_dim)
def get_motion(self, img):
#motion_feat = self.enc.enc_motion(img)
# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
with torch.cuda.amp.autocast(dtype=torch.float32):
motion_feat = self.enc.enc_motion(img)
motion = self.dec.direction(motion_feat)
return motion
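For orientation, a short sketch of how this LIA-style motion encoder is queried; the 512x512 crop size and the batch of face frames are assumptions for illustration, while the 512-dim output follows from the Direction module above:

import torch

motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
faces = torch.randn(4, 3, 512, 512)          # a batch of (assumed) normalized face crops
motion = motion_encoder.get_motion(faces)    # -> torch.Size([4, 512]) motion vector per frame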

View File

@ -203,10 +203,7 @@ class WanAny2V:
self.use_timestep_transform = True
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
ref_images = [ref_images] * len(frames)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
@ -238,11 +235,7 @@ class WanAny2V:
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
ref_images = [ref_images] * len(masks)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
@ -270,79 +263,6 @@ class WanAny2V:
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
image_sizes = []
trim_video_guide = len(keep_video_guide_frames)
def conv_tensor(t, device):
return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
if sub_src_mask is not None and sub_src_video is not None:
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
# src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[i][:, pos:pos+1] = 0
src_mask[i][:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
self.background_mask = None
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None and not torch.is_tensor(ref_img):
if j==0 and any_background_ref:
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
else:
src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
if self.background_mask != None:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
return src_video, src_mask, src_ref_images
def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = []
@ -369,7 +289,9 @@ class WanAny2V:
def generate(self,
input_prompt,
input_frames= None,
input_frames2= None,
input_masks = None,
input_masks2 = None,
input_ref_images = None,
input_ref_masks = None,
input_faces = None,
@ -615,21 +537,22 @@ class WanAny2V:
pose_pixels = input_frames * input_masks
input_masks = 1. - input_masks
pose_pixels -= input_masks
save_video(pose_pixels, "pose.mp4")
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
input_frames = input_frames * input_masks
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
if prefix_frames_count > 0:
input_frames[:, :prefix_frames_count] = input_video
input_masks[:, :prefix_frames_count] = 1
save_video(input_frames, "input_frames.mp4")
save_video(input_masks, "input_masks.mp4", value_range=(0,1))
# save_video(pose_pixels, "pose.mp4")
# save_video(input_frames, "input_frames.mp4")
# save_video(input_masks, "input_masks.mp4", value_range=(0,1))
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
msk = torch.concat([msk_ref, msk_control], dim=1)
clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
image_ref = input_ref_images[0].to(self.device)
clip_image_start = image_ref.squeeze(1)
lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
y = torch.concat([msk, lat_y])
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
lat_y = msk = msk_control = msk_ref = pose_pixels = None
@ -701,12 +624,11 @@ class WanAny2V:
# Phantom
if phantom:
input_ref_images_neg = None
if input_ref_images != None: # Phantom Ref images
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
input_ref_images_neg = torch.zeros_like(input_ref_images)
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
trim_frames = input_ref_images.shape[1]
lat_input_ref_images_neg = None
if input_ref_images is not None: # Phantom Ref images
lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
ref_images_count = trim_frames = lat_input_ref_images.shape[1]
if ti2v:
if input_video is None:
@ -721,25 +643,23 @@ class WanAny2V:
# Vace
if vace :
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
ref_images_before = True
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
if self.background_mask != None:
color_reference_frame = input_ref_images[0][0].clone()
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(self.background_mask, None)
if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
color_reference_frame = input_ref_images[0].clone()
zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
zz0[:, 0:1] = zzbg
mm0[:, 0:1] = mmbg
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
z = self.vace_latent(z0, m0)
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
zz0 = mm0 = zzbg = mmbg = None
z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None :
@ -747,15 +667,8 @@ class WanAny2V:
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
if prefix_frames_count > 0:
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
lat_h, lat_w = target_shape[-2:]
height = self.vae_stride[1] * lat_h
width = self.vae_stride[2] * lat_w
else:
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
if multitalk:
if audio_proj is None:
@ -860,7 +773,9 @@ class WanAny2V:
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
gc.collect()
torch.cuda.empty_cache()
# denoising
trans = self.model
@ -878,7 +793,7 @@ class WanAny2V:
kwargs.update({"t": timestep, "current_step": start_step_no + i})
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
if denoising_strength < 1 and i <= injection_denoising_step:
sigma = t / 1000
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
@ -912,8 +827,8 @@ class WanAny2V:
any_guidance = guide_scale != 1
if phantom:
gen_args = {
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
[ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] ,
}
elif fantasy:

View File

@ -21,6 +21,7 @@ class family_handler():
extra_model_def["fps"] =fps
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 20
extra_model_def["latent_size"] = 4
extra_model_def["sliding_window"] = True
extra_model_def["skip_layer_guidance"] = True
extra_model_def["tea_cache"] = True

View File

@ -114,7 +114,6 @@ class family_handler():
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"convert_image_guide_to_video" : True,
"sample_solvers":[
("unipc", "unipc"),
("euler", "euler"),
@ -175,6 +174,8 @@ class family_handler():
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
extra_model_def["return_image_refs_tensor"] = True
extra_model_def["guide_inpaint_color"] = 0
@ -196,15 +197,15 @@ class family_handler():
"letters_filter": "KFI",
}
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["return_image_refs_tensor"] = True
if base_model_type in ["standin"]:
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["fit_into_canvas_image_refs"] = 0
extra_model_def["image_ref_choices"] = {
"choices": [
("No Reference Image", ""),
@ -480,6 +481,7 @@ class family_handler():
ui_defaults.update({
"video_prompt_type": "PVBXAKI",
"mask_expand": 20,
"audio_prompt_type_value": "R",
})
if text_oneframe_overlap(base_model_type):

View File

@ -32,6 +32,14 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
def has_video_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]
def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
@ -94,7 +102,7 @@ def get_video_info(video_path):
return fps, width, height, frame_count
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)
@ -102,6 +110,9 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
raise ValueError(f"Cannot open video: {file_name}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = round(cap.get(cv2.CAP_PROP_FPS))
if target_fps is not None:
frame_no = round(target_fps * frame_no /fps)
# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
@ -175,10 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
def convert_tensor_to_image(t, frame_no = 0):
def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
if len(t.shape) == 4:
t = t[:, frame_no]
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
if t.shape[0]== 1:
t = t.expand(3,-1,-1)
if mask_levels:
return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
else:
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
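A quick sketch of the two value conventions this helper now distinguishes (the tensors below are made up): RGB frames are expected in [-1, 1] and mapped through add(1) * 127.5, while mask_levels=True takes a single-channel mask in [0, 1] and simply scales it to 0-255.

import torch

frame = torch.zeros(3, 8, 64, 64)                                        # c, f, h, w in [-1, 1]
img = convert_tensor_to_image(frame, frame_no=0)                         # mid-grey 64x64 RGB image
mask = torch.ones(1, 8, 64, 64)                                          # single-channel mask in [0, 1]
mask_img = convert_tensor_to_image(mask, frame_no=0, mask_levels=True)   # solid white 64x64 image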
def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
if rm_background:
session = new_session()
@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
for i, img in enumerate(img_list):
width, height = img.size
resized_mask = None
if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
if any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
if return_tensor:
output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
else:
output_list.append(resized_image)
output_mask_list.append(resized_mask)
return output_list, output_mask_list
@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
return ref_img.to(device), canvas
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
inpaint_color = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
inpaint_color_compressed = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
src_video = src_mask = None
if cur_video_guide is not None:
src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
if cur_video_mask is not None and any_mask:
src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
if pre_video_guide is not None and not extract_guide_from_window_start:
src_video, src_mask = cur_video_guide, cur_video_mask
if pre_video_guide is not None:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
if src_video is None:
if any_guide_padding:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
if any_mask:
src_mask = torch.zeros_like(src_video[0:1])
elif src_video.shape[1] < current_video_length:
if any_guide_padding:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
if cur_video_mask is not None and any_mask:
src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
if any_guide_padding:
if src_video is None:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
elif src_video.shape[1] < current_video_length:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
elif src_video is not None:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]
if any_mask and src_video is not None:
if src_mask is None:
src_mask = torch.ones_like(src_video[:1])
elif src_mask.shape[1] < src_video.shape[1]:
src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
else:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]
if any_mask:
src_mask = src_mask[:, :new_num_frames]
src_mask = src_mask[:, :src_video.shape[1]]
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color
src_mask[:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
if src_video is not None :
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color_compressed
if any_mask: src_mask[:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
if any_mask: src_mask[:, pos:pos+1] = msk
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks
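For orientation, a hedged sketch of how the reworked helper might be called (the shapes and argument values are illustrative): guides are expected as (c, f, h, w) tensors already normalized to [-1, 1], and with any_guide_padding a short guide is padded up to the requested length with the inpaint colour.

import torch

guide = torch.zeros(3, 33, 480, 832)          # c, f, h, w in [-1, 1]
src_videos, src_masks = prepare_video_guide_and_mask(
    [guide], [None], pre_video_guide=None, image_size=(480, 832),
    current_video_length=81, latent_size=4,
    any_mask=False, any_guide_padding=True, guide_inpaint_color=127.5)
print(src_videos[0].shape)                    # torch.Size([3, 81, 480, 832]); the 48 padded frames are 0.0 (mid grey)
print(src_masks[0])                           # None, since any_mask=False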

396
wgp.py
View File

@ -24,6 +24,7 @@ from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.utils import has_video_file_extension, has_image_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
WanGP_version = "8.61"
WanGP_version = "8.7"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type):
mode_def = get_model_def(model_type)
frames_minimum = mode_def.get("frames_minimum", 5)
frames_steps = mode_def.get("frames_steps", 4)
return frames_minimum, frames_steps
latent_size = mode_def.get("latent_size", frames_steps)
return frames_minimum, frames_steps, latent_size
def get_model_fps(model_type):
mode_def = get_model_def(model_type)
@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if len(video_other_prompts) >0 :
values += [video_other_prompts]
labels += ["Other Prompts"]
if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
if len(video_outpainting) >0:
values += [video_outpainting]
labels += ["Outpainting"]
video_sample_solver = configs.get("sample_solver", "")
@ -3532,6 +3534,11 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))
def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
if isinstance(video_in, str) and has_image_file_extension(video_in):
video_in = Image.open(video_in)
if isinstance(video_in, Image.Image):
return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)
from shared.utils.utils import resample
import decord
@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count() // 2):
if not items:
return []
max_workers = 11
import concurrent.futures
start_time = time.time()
# print(f"Preprocessus:{process_type} started")
if process_type in ["prephase", "upsample"]:
if wrap_in_list :
items = [ [img] for img in items]
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
results = [None] * len(items)
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
results[idx] = future.result()
if max_workers == 1:
results = [image_processor(img) for img in items]
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
results = [None] * len(items)
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
results[idx] = future.result()
if wrap_in_list:
results = [ img[0] for img in results]
@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis
return results
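For reference, a self-contained sketch (not repo code) of the order-preserving thread-pool fan-out that process_images_multithread implements above, including the new max_workers == 1 sequential fallback that skips the pool overhead entirely:

import concurrent.futures

def run_ordered(fn, items, max_workers=4):
    if max_workers == 1:
        return [fn(x) for x in items]               # sequential fallback, no thread pool
    results = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(fn, x): idx for idx, x in enumerate(items)}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()  # write back by submission index
    return results

print(run_ordered(lambda x: x * x, [1, 2, 3, 4]))   # [1, 4, 9, 16], order preserved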
def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
frame_width, frame_height = input_image.size
if fit_crop:
input_image = rescale_and_crop(input_image, width, height)
if input_mask is not None:
input_mask = rescale_and_crop(input_mask, width, height)
return input_image, input_mask
if outpainting_dims != None:
if fit_canvas != None:
frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
else:
frame_height, frame_width = height, width
if fit_canvas != None:
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
if outpainting_dims != None:
final_height, final_width = height, width
height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
if fit_canvas != None or outpainting_dims != None:
input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
if input_mask is not None:
input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)
if expand_scale != 0 and input_mask is not None:
kernel_size = abs(expand_scale)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
input_mask = np.array(input_mask)
input_mask = op_expand(input_mask, kernel, iterations=3)
input_mask = Image.fromarray(input_mask)
if outpainting_dims != None:
inpaint_color = inpaint_color / 127.5-1
image = convert_image_to_tensor(input_image)
full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device)
full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
input_image = convert_tensor_to_image(full_frame)
if input_mask is not None:
mask = convert_image_to_tensor(input_mask)
full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
input_mask = convert_tensor_to_image(full_frame)
return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor
def get_default_workers():
return os.cpu_count() // 2
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
return (target_frame, frame, mask)
else:
return (target_frame, None, None)
proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
max_workers = get_default_workers()
proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers)
proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
for frame_idx, frame_group in enumerate(proc_lists):
proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
mask_video = None
if preproc2 != None:
proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
#### to be finished ...or not
proc_list = process_images_multithread(preproc, proc_list, process_type)
proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
if any_mask:
proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
else:
proc_list_outside = proc_mask = len(proc_list) * [None]
@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
mask = full_frame
masks.append(mask)
masks.append(mask[:, :, 0:1].clone())
else:
masked_frame = processed_img
@ -3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
if args.save_masks:
from preprocessing.dwpose.pose import save_one_video
saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
if any_mask:
saved_masks = [mask.cpu().numpy() for mask in masks ]
save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
# if args.save_masks:
# from preprocessing.dwpose.pose import save_one_video
# saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
# save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
# if any_mask:
# saved_masks = [mask.cpu().numpy() for mask in masks ]
# save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
preproc = None
preproc_outside = None
gc.collect()
@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
if pad_frames > 0:
masked_frames = [masked_frames[0]] * pad_frames + masked_frames
if any_mask: masks = [masks[0]] * pad_frames + masks
masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None
return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
return masked_frames, masks
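preprocess_video_with_mask now returns tensors already normalized for the pipeline: frames as a C, F, H, W float tensor in [-1, 1] and masks as a single-channel float tensor in [0, 1]. A minimal sketch of that convention with dummy uint8 frames (shapes are illustrative only):

import torch

frames = [torch.randint(0, 256, (480, 832, 3), dtype=torch.uint8) for _ in range(5)]
masks = [torch.randint(0, 2, (480, 832, 1), dtype=torch.uint8) * 255 for _ in range(5)]

video = torch.stack(frames).permute(-1, 0, 1, 2).float().div_(127.5).sub_(1.)   # (3, 5, 480, 832), values in [-1, 1]
mask = torch.stack(masks).permute(-1, 0, 1, 2).float().div_(255)                # (1, 5, 480, 832), values in [0, 1]
assert video.shape == (3, 5, 480, 832) and mask.shape == (1, 5, 480, 832)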
def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):
frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
def upsample_frames(frame):
return resize_lanczos(frame, h, w).unsqueeze(1)
sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1)
frames_to_upsample = None
return sample
@ -4609,17 +4574,13 @@ def generate_video(
batch_size = 1
temp_filenames_list = []
convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
if convert_image_guide_to_video:
if image_guide is not None and isinstance(image_guide, Image.Image):
video_guide = convert_image_to_video(image_guide)
temp_filenames_list.append(video_guide)
image_guide = None
if image_guide is not None and isinstance(image_guide, Image.Image):
video_guide = image_guide
image_guide = None
if image_mask is not None and isinstance(image_mask, Image.Image):
video_mask = convert_image_to_video(image_mask)
temp_filenames_list.append(video_mask)
image_mask = None
if image_mask is not None and isinstance(image_mask, Image.Image):
video_mask = image_mask
image_mask = None
if model_def.get("no_background_removal", False): remove_background_images_ref = 0
@ -4711,22 +4672,12 @@ def generate_video(
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
i2v = test_class_i2v(model_type)
diffusion_forcing = "diffusion_forcing" in model_filename
t2v = base_model_type in ["t2v"]
ltxv = "ltxv" in model_filename
vace = test_vace_module(base_model_type)
hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_custom = "hunyuan_video_custom" in model_filename
hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
animate = base_model_type in ["animate"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@ -4763,9 +4714,9 @@ def generate_video(
sliding_window_size = current_video_length
reuse_frames = 0
_, latent_size = get_model_min_frames_and_step(model_type)
if diffusion_forcing: latent_size = 4
_, _, latent_size = get_model_min_frames_and_step(model_type)
original_image_refs = image_refs
image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
# image_refs = None
# nb_frames_positions= 0
# Output Video Ratio Priorities:
@ -4889,6 +4840,7 @@ def generate_video(
initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
nb_frames_positions = 0
if sliding_window:
initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
current_video_length = sliding_window_size
@ -4907,7 +4859,7 @@ def generate_video(
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overlapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before)
@ -4963,7 +4915,6 @@ def generate_video(
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
@ -5020,7 +4971,7 @@ def generate_video(
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
if cur_end_pos >= last_frame_no and not joker_used:
if cur_end_pos >= last_frame_no-1 and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
@ -5036,141 +4987,53 @@ def generate_video(
frames_to_inject[pos] = image_refs[i]
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
if video_guide is not None:
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
extra_control_frames = model_def.get("extra_control_frames", 0)
if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
guide_frames_extract_count = len(keep_frames_parsed)
# Extract Faces to video
if "B" in video_prompt_type:
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
if src_faces is not None and src_faces.shape[1] < current_video_length:
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
if vace or animate:
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
context_scale = [ control_net_weight]
if "V" in video_prompt_type:
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
preprocess_type, preprocess_type2 = "raw", None
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
if process_num == 0:
preprocess_type = process_map_video_guide.get(process_letter, "raw")
else:
preprocess_type2 = process_map_video_guide.get(process_letter, None)
status_info = "Extracting " + processes_names[preprocess_type]
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
if len(extra_process_list) == 1:
status_info += " and " + processes_names[extra_process_list[0]]
elif len(extra_process_list) == 2:
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
if preprocess_type2 is not None:
context_scale = [ control_net_weight /2, control_net_weight2 /2]
send_cmd("progress", [0, get_latest_status(state, status_info)])
inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
if preprocess_type2 != None:
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
# Sparse Video to Video
sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None
if video_guide_processed != None:
if sample_fit_canvas != None:
image_size = video_guide_processed.shape[-3: -1]
sample_fit_canvas = None
refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
if video_guide_processed2 != None:
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
if video_mask_processed != None:
refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
any_mask = True
any_guide_padding = model_def.get("pad_guide_video", False)
from shared.utils.utils import prepare_video_guide_and_mask
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
[video_mask_processed, video_mask_processed2],
pre_video_guide, image_size, current_video_length, latent_size,
any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
src_video, src_video2 = src_videos
src_mask, src_mask2 = src_masks
if src_video is None:
abort = True
break
if src_faces is not None:
if src_faces.shape[1] < src_video.shape[1]:
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
else:
src_faces = src_faces[:, :src_video.shape[1]]
if args.save_masks:
save_video( src_video, "masked_frames.mp4", fps)
if src_video2 is not None:
save_video( src_video2, "masked_frames2.mp4", fps)
if any_mask:
save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
elif ltxv:
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
status_info = "Extracting " + processes_names[preprocess_type]
send_cmd("progress", [0, get_latest_status(state, status_info)])
# start one frame earlier to facilitate latents merging later
src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
if src_video != None:
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
refresh_preview["video_mask"] = None
src_video = src_video.permute(3, 0, 1, 2)
src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
if sample_fit_canvas != None:
image_size = src_video.shape[-2:]
sample_fit_canvas = None
elif hunyuan_custom_edit:
if "P" in video_prompt_type:
progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
# Generic Video Preprocessing
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
preprocess_type, preprocess_type2 = "raw", None
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
if process_num == 0:
preprocess_type = process_map_video_guide.get(process_letter, "raw")
else:
progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
preprocess_type2 = process_map_video_guide.get(process_letter, None)
status_info = "Extracting " + processes_names[preprocess_type]
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
if len(extra_process_list) == 1:
status_info += " and " + processes_names[extra_process_list[0]]
elif len(extra_process_list) == 2:
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)])
inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size )
if preprocess_type2 != None:
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size )
send_cmd("progress", progress_args)
src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
if src_mask != None:
refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
elif "R" in video_prompt_type: # sparse video to video
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size)
refresh_preview["video_guide"] = src_image
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
if sample_fit_canvas != None:
image_size = src_video.shape[-2:]
sample_fit_canvas = None
else: # video to video
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
if video_guide_processed is None:
src_video = pre_video_guide
else:
if sample_fit_canvas != None:
image_size = video_guide_processed.shape[-3: -1]
sample_fit_canvas = None
src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
if pre_video_guide != None:
src_video = torch.cat( [pre_video_guide, src_video], dim=1)
elif image_guide is not None:
new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims)
if sample_fit_canvas is not None:
image_size = (new_image_guide.size[1], new_image_guide.size[0])
if video_guide_processed is not None and sample_fit_canvas is not None:
image_size = video_guide_processed.shape[-2:]
sample_fit_canvas = None
refresh_preview["image_guide"] = new_image_guide
if new_image_mask is not None:
refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
@ -5193,44 +5056,67 @@ def generate_video(
refresh_preview["image_refs"] = image_refs
if len(image_refs) > nb_frames_positions:
src_ref_images = image_refs[nb_frames_positions:]
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
# keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0],
remove_background_images_ref > 0, any_background_ref,
fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1),
block_size=block_size,
outpainting_dims =outpainting_dims,
background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
refresh_preview["image_refs"] = image_refs
background_ref_outpainted = model_def.get("background_ref_outpainted", True),
return_tensor= model_def.get("return_image_refs_tensor", False) )
if vace :
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
[video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2],
[image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy],
current_video_length, image_size = image_size, device ="cpu",
keep_video_guide_frames=keep_frames_parsed,
pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
any_background_ref = any_background_ref
)
if len(frames_to_inject_parsed) or any_background_ref:
new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
any_guide_padding = model_def.get("pad_guide_video", False)
from shared.utils.utils import prepare_video_guide_and_mask
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
[video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
image_size, current_video_length, latent_size,
any_mask, any_guide_padding, guide_inpaint_color,
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
if len(src_videos) == 1:
src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None
else:
src_video, src_video2 = src_videos
src_mask, src_mask2 = src_masks
src_videos = src_masks = None
if src_video is None:
abort = True
break
if src_faces is not None:
if src_faces.shape[1] < src_video.shape[1]:
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
else:
new_image_refs += image_refs[nb_frames_positions:]
refresh_preview["image_refs"] = new_image_refs
new_image_refs = None
if sample_fit_canvas != None:
image_size = src_video[0].shape[-2:]
sample_fit_canvas = None
src_faces = src_faces[:, :src_video.shape[1]]
if video_guide is not None or len(frames_to_inject_parsed) > 0:
if args.save_masks:
if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
if video_guide is not None:
preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
if src_video2 is not None:
refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
if src_mask is not None and video_mask is not None:
refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True)
if src_ref_images is not None or nb_frames_positions:
if len(frames_to_inject_parsed):
new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
else:
new_image_refs = []
if src_ref_images is not None:
new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ]
refresh_preview["image_refs"] = new_image_refs
new_image_refs = None
if len(refresh_preview) > 0:
new_inputs= locals()
@ -5339,8 +5225,6 @@ def generate_video(
pre_video_frame = pre_video_frame,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
image_guide= new_image_guide,
image_mask= new_image_mask,
outpainting_dims = outpainting_dims,
)
except Exception as e:
@ -6320,7 +6204,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
pop += ["image_refs_relative_size"]
if not vace:
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
pop += ["frames_positions", "control_net_weight", "control_net_weight2"]
if model_def.get("video_guide_outpainting", None) is None:
pop += ["video_guide_outpainting"]
if not (vace or t2v):
pop += ["min_frames_if_references"]
@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice):
choice = min(choice, len(file_list))
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def has_video_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]
def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
if files_to_load == None:
@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
elif recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
else:
min_frames, frames_step = get_model_min_frames_and_step(base_model_type)
min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)
current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)")
with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row:
with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row:
gr.Markdown("<B>You may transfer the existing audio tracks of a Control Video</B>")
audio_prompt_type_remux = gr.Dropdown(
choices=[
@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
with gr.Row(**default_visibility) as video_buttons_row:
video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
with gr.Row(**default_visibility) as image_buttons_row:
video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False)
video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
with gr.Group(elem_classes= "postprocess"):