mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
Vace Contenders are in Town
This commit is contained in:
parent
84010bd861
commit
e28c95ae91
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates :

### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !

So in today's release you will find two Vace wannabes, each covering only a subset of Vace features but offering some interesting advantages:

- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers, and it does that very well. You can use this model either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Lora Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusive, you will find support for *Outpainting*.

- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.

### September 15 2025: WanGP v8.6 - Attack of the Clones

- The long-awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**).

@@ -2,7 +2,7 @@
    "model": {
        "name": "Wan2.2 Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
        ],
        "group": "wan2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 5,
    "flow_shift": 5,

16 defaults/lucy_edit_fastwan.json (Normal file)
@@ -0,0 +1,16 @@
{
    "model": {
        "name": "Wan2.2 FastWan Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
        "URLs": "lucy_edit",
        "group": "wan2_2",
        "loras": "ti2v_2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 5,
    "resolution": "1280x720"
}
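For reference, a minimal standalone sketch (in Python, not WanGP's own loader; the path is assumed relative to the repository root) that reads the FastWan profile above and prints the settings that make it fast: only 5 inference steps with guidance 1 and flow shift 3, while the "URLs": "lucy_edit" indirection reuses the base Lucy Edit checkpoints.

import json

with open("defaults/lucy_edit_fastwan.json") as f:
    cfg = json.load(f)

print(cfg["model"]["name"])   # Wan2.2 FastWan Lucy Edit 5B
print(cfg["model"]["URLs"])   # "lucy_edit" -> weights shared with the base profile
print(cfg["num_inference_steps"], cfg["guidance_scale"], cfg["flow_shift"])  # 5 1 3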
@ -56,7 +56,7 @@ class family_handler():
|
||||
}
|
||||
|
||||
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
extra_model_def["fit_into_canvas_image_refs"] = 0
|
||||
|
||||
return extra_model_def
|
||||
|
||||
|
||||
@ -142,8 +142,8 @@ class model_factory:
|
||||
n_prompt: str = None,
|
||||
sampling_steps: int = 20,
|
||||
input_ref_images = None,
|
||||
image_guide= None,
|
||||
image_mask= None,
|
||||
input_frames= None,
|
||||
input_masks= None,
|
||||
width= 832,
|
||||
height=480,
|
||||
embedded_guidance_scale: float = 2.5,
|
||||
@ -197,10 +197,12 @@ class model_factory:
|
||||
for new_img in input_ref_images[1:]:
|
||||
stiched = stitch_images(stiched, new_img)
|
||||
input_ref_images = [stiched]
|
||||
elif image_guide is not None:
|
||||
input_ref_images = [image_guide]
|
||||
elif input_frames is not None:
|
||||
input_ref_images = [convert_tensor_to_image(input_frames) ]
|
||||
else:
|
||||
input_ref_images = None
|
||||
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
|
||||
|
||||
|
||||
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
|
||||
inp, height, width = prepare_multi_ip(
|
||||
@ -253,8 +255,8 @@ class model_factory:
|
||||
if image_mask is not None:
|
||||
from shared.utils.utils import convert_image_to_tensor
|
||||
img_msk_rebuilt = inp["img_msk_rebuilt"]
|
||||
img= convert_image_to_tensor(image_guide)
|
||||
x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
|
||||
img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
|
||||
x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
|
||||
|
||||
x = x.clamp(-1, 1)
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference):
|
||||
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
|
||||
else:
|
||||
if input_frames != None:
|
||||
target_height, target_width = input_frames.shape[-3:-1]
|
||||
target_height, target_width = input_frames.shape[-2:]
|
||||
elif input_video != None:
|
||||
target_height, target_width = input_video.shape[-2:]
|
||||
|
||||
@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference):
|
||||
pixel_value_bg = input_video.unsqueeze(0)
|
||||
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
|
||||
if input_frames != None:
|
||||
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
|
||||
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
|
||||
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
|
||||
pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
|
||||
# pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
|
||||
# pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
|
||||
pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
|
||||
if input_video != None:
|
||||
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
|
||||
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
|
||||
@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference):
|
||||
if pixel_value_bg.shape[2] < frame_num:
|
||||
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
|
||||
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
|
||||
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
|
||||
# pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
|
||||
pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
|
||||
|
||||
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
|
||||
pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
|
||||
pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1
|
||||
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
|
||||
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
|
||||
bg_latents.mul_(self.vae.config.scaling_factor)
|
||||
|
||||
@ -35,6 +35,8 @@ class family_handler():
|
||||
"selection": ["", "A", "NA", "XA", "XNA"],
|
||||
}
|
||||
|
||||
extra_model_def["extra_control_frames"] = 1
|
||||
extra_model_def["dont_cat_preguide"]= True
|
||||
return extra_model_def
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -17,7 +17,7 @@ class family_handler():
|
||||
("Default", "default"),
|
||||
("Lightning", "lightning")],
|
||||
"guidance_max_phases" : 1,
|
||||
"lock_image_refs_ratios": True,
|
||||
"fit_into_canvas_image_refs": 0,
|
||||
}
|
||||
|
||||
if base_model_type in ["qwen_image_edit_20B"]:
|
||||
|
||||
@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
from .pipeline_qwenimage import QwenImagePipeline
|
||||
from PIL import Image
|
||||
from shared.utils.utils import calculate_new_dimensions
|
||||
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
|
||||
|
||||
def stitch_images(img1, img2):
|
||||
# Resize img2 to match img1's height
|
||||
@ -103,8 +103,8 @@ class model_factory():
|
||||
n_prompt = None,
|
||||
sampling_steps: int = 20,
|
||||
input_ref_images = None,
|
||||
image_guide= None,
|
||||
image_mask= None,
|
||||
input_frames= None,
|
||||
input_masks= None,
|
||||
width= 832,
|
||||
height=480,
|
||||
guide_scale: float = 4,
|
||||
@ -179,8 +179,10 @@ class model_factory():
|
||||
|
||||
if n_prompt is None or len(n_prompt) == 0:
|
||||
n_prompt= "text, watermark, copyright, blurry, low resolution"
|
||||
if image_guide is not None:
|
||||
input_ref_images = [image_guide]
|
||||
|
||||
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
|
||||
if input_frames is not None:
|
||||
input_ref_images = [convert_tensor_to_image(input_frames) ]
|
||||
elif input_ref_images is not None:
|
||||
# image stitching method
|
||||
stiched = input_ref_images[0]
|
||||
|
||||
143 models/wan/animate/animate_utils.py (Normal file)
@@ -0,0 +1,143 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import numbers
from peft import LoraConfig


def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
    target_modules = []
    for name, module in transformer.named_modules():
        if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
            target_modules.append(name)

    transformer_lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        init_lora_weights=init_lora_weights,
        target_modules=target_modules,
    )
    return transformer_lora_config
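
# Illustrative sketch (added for this write-up, not part of the original file):
# one way the returned LoraConfig could be attached using peft's generic helper.
# The `transformer` argument is any nn.Module whose "blocks.*" Linear layers
# should receive adapters; `inject_adapter_in_model` is assumed from a recent
# peft release rather than taken from WanGP's own model loading code.
def _example_inject_lora(transformer, rank=128, alpha=128):
    from peft import inject_adapter_in_model

    lora_config = get_loraconfig(transformer, rank=rank, alpha=alpha)
    # Injects LoRA layers in place and returns the modified module.
    return inject_adapter_in_model(lora_config, transformer)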
class TensorList(object):
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""
|
||||
tensors: a list of torch.Tensor objects. No need to have uniform shape.
|
||||
"""
|
||||
assert isinstance(tensors, (list, tuple))
|
||||
assert all(isinstance(u, torch.Tensor) for u in tensors)
|
||||
assert len(set([u.ndim for u in tensors])) == 1
|
||||
assert len(set([u.dtype for u in tensors])) == 1
|
||||
assert len(set([u.device for u in tensors])) == 1
|
||||
self.tensors = tensors
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
|
||||
|
||||
def size(self, dim):
|
||||
assert dim == 0, 'only support get the 0th size'
|
||||
return len(self.tensors)
|
||||
|
||||
def pow(self, *args, **kwargs):
|
||||
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
|
||||
|
||||
def squeeze(self, dim):
|
||||
assert dim != 0
|
||||
if dim > 0:
|
||||
dim -= 1
|
||||
return TensorList([u.squeeze(dim) for u in self.tensors])
|
||||
|
||||
def type(self, *args, **kwargs):
|
||||
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
|
||||
|
||||
def type_as(self, other):
|
||||
assert isinstance(other, (torch.Tensor, TensorList))
|
||||
if isinstance(other, torch.Tensor):
|
||||
return TensorList([u.type_as(other) for u in self.tensors])
|
||||
else:
|
||||
return TensorList([u.type(other.dtype) for u in self.tensors])
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensors[0].device
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return 1 + self.tensors[0].ndim
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.tensors[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tensors)
|
||||
|
||||
def __add__(self, other):
|
||||
return self._apply(other, lambda u, v: u + v)
|
||||
|
||||
def __radd__(self, other):
|
||||
return self._apply(other, lambda u, v: v + u)
|
||||
|
||||
def __sub__(self, other):
|
||||
return self._apply(other, lambda u, v: u - v)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return self._apply(other, lambda u, v: v - u)
|
||||
|
||||
def __mul__(self, other):
|
||||
return self._apply(other, lambda u, v: u * v)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return self._apply(other, lambda u, v: v * u)
|
||||
|
||||
def __floordiv__(self, other):
|
||||
return self._apply(other, lambda u, v: u // v)
|
||||
|
||||
def __truediv__(self, other):
|
||||
return self._apply(other, lambda u, v: u / v)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return self._apply(other, lambda u, v: v // u)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return self._apply(other, lambda u, v: v / u)
|
||||
|
||||
def __pow__(self, other):
|
||||
return self._apply(other, lambda u, v: u ** v)
|
||||
|
||||
def __rpow__(self, other):
|
||||
return self._apply(other, lambda u, v: v ** u)
|
||||
|
||||
def __neg__(self):
|
||||
return TensorList([-u for u in self.tensors])
|
||||
|
||||
def __iter__(self):
|
||||
for tensor in self.tensors:
|
||||
yield tensor
|
||||
|
||||
def __repr__(self):
|
||||
return 'TensorList: \n' + repr(self.tensors)
|
||||
|
||||
def _apply(self, other, op):
|
||||
if isinstance(other, (list, tuple, TensorList)) or (
|
||||
isinstance(other, torch.Tensor) and (
|
||||
other.numel() > 1 or other.ndim > 1
|
||||
)
|
||||
):
|
||||
assert len(other) == len(self.tensors)
|
||||
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
|
||||
elif isinstance(other, numbers.Number) or (
|
||||
isinstance(other, torch.Tensor) and (
|
||||
other.numel() == 1 and other.ndim <= 1
|
||||
)
|
||||
):
|
||||
return TensorList([op(u, other) for u in self.tensors])
|
||||
else:
|
||||
raise TypeError(
|
||||
f'unsupported operand for *: "TensorList" and "{type(other)}"'
|
||||
)
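
# Illustrative sketch (added for this write-up, not part of the original file):
# TensorList lets a ragged "batch" of tensors (different lengths, same ndim,
# dtype and device) flow through scheduler-style arithmetic that would normally
# expect a single stacked tensor.
def _example_tensorlist():
    a = TensorList([torch.randn(4, 16), torch.randn(7, 16)])  # ragged first dim
    b = (a * 0.5 + 1.0).to(torch.float16)  # scalars broadcast to every member
    assert len(b) == 2 and b[0].shape == (4, 16) and b.dtype == torch.float16
    return b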
382 models/wan/animate/face_blocks.py (Normal file)
@ -0,0 +1,382 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
from torch import nn
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from shared.attention import pay_attention
|
||||
|
||||
MEMORY_LAYOUT = {
|
||||
"flash": (
|
||||
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
||||
lambda x: x,
|
||||
),
|
||||
"torch": (
|
||||
lambda x: x.transpose(1, 2),
|
||||
lambda x: x.transpose(1, 2),
|
||||
),
|
||||
"vanilla": (
|
||||
lambda x: x.transpose(1, 2),
|
||||
lambda x: x.transpose(1, 2),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
mode="torch",
|
||||
drop_rate=0,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
max_seqlen_q=None,
|
||||
batch_size=1,
|
||||
):
|
||||
"""
|
||||
Perform QKV self attention.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
||||
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
||||
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
||||
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
||||
drop_rate (float): Dropout rate in attention map. (default: 0)
|
||||
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
||||
(default: None)
|
||||
causal (bool): Whether to use causal attention. (default: False)
|
||||
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
||||
used to index into q.
|
||||
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
||||
used to index into kv.
|
||||
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
||||
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
||||
"""
|
||||
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
||||
|
||||
if mode == "torch":
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
||||
|
||||
elif mode == "flash":
|
||||
x = flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
||||
elif mode == "vanilla":
|
||||
scale_factor = 1 / math.sqrt(q.size(-1))
|
||||
|
||||
b, a, s, _ = q.shape
|
||||
s1 = k.size(2)
|
||||
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
||||
if causal:
|
||||
# Only applied to self attention
|
||||
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
||||
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(q.dtype)
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
||||
else:
|
||||
attn_bias += attn_mask
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
||||
attn += attn_bias
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = torch.dropout(attn, p=drop_rate, train=True)
|
||||
x = attn @ v
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
||||
|
||||
x = post_attn_layout(x)
|
||||
b, s, a, d = x.shape
|
||||
out = x.reshape(b, s, -1)
|
||||
return out
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)
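
# Illustrative sketch (added for this write-up, not part of the original file):
# the (kernel_size - 1, 0) replicate padding keeps the convolution causal along
# time and, with stride 1, preserves the temporal length.
def _example_causal_conv():
    conv = CausalConv1d(chan_in=16, chan_out=32, kernel_size=3)
    y = conv(torch.randn(2, 16, 10))  # input is [B, C, T]
    assert y.shape == (2, 32, 10)     # same T, strictly causal receptive field
    return y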
class FaceEncoder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.act = nn.SiLU()
|
||||
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||
|
||||
self.out_proj = nn.Linear(1024, hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
b, c, t = x.shape
|
||||
|
||||
x = self.conv1_local(x)
|
||||
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
|
||||
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv2(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
x = self.conv3(x)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
x = self.norm3(x)
|
||||
x = self.act(x)
|
||||
x = self.out_proj(x)
|
||||
x = rearrange(x, "(b n) t c -> b t n c", b=b)
|
||||
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
|
||||
x = torch.cat([x, padding], dim=-2)
|
||||
x_local = x.clone()
|
||||
|
||||
return x_local
|
||||
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
def get_norm_layer(norm_layer):
|
||||
"""
|
||||
Get the normalization layer.
|
||||
|
||||
Args:
|
||||
norm_layer (str): The type of normalization layer.
|
||||
|
||||
Returns:
|
||||
norm_layer (nn.Module): The normalization layer.
|
||||
"""
|
||||
if norm_layer == "layer":
|
||||
return nn.LayerNorm
|
||||
elif norm_layer == "rms":
|
||||
return RMSNorm
|
||||
else:
|
||||
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||
|
||||
|
||||
class FaceAdapter(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
num_adapter_layers: int = 1,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_dim
|
||||
self.heads_num = heads_num
|
||||
self.fuser_blocks = nn.ModuleList(
|
||||
[
|
||||
FaceBlock(
|
||||
self.hidden_size,
|
||||
self.heads_num,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(num_adapter_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_embed: torch.Tensor,
|
||||
idx: int,
|
||||
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
||||
|
||||
|
||||
|
||||
class FaceBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qk_scale: float = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
||||
self.deterministic = False
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
|
||||
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||
|
||||
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||
self.q_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
self.k_norm = (
|
||||
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||
)
|
||||
|
||||
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
motion_vec: torch.Tensor,
|
||||
motion_mask: Optional[torch.Tensor] = None,
|
||||
use_context_parallel=False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
B, T, N, C = motion_vec.shape
|
||||
T_comp = T
|
||||
|
||||
x_motion = self.pre_norm_motion(motion_vec)
|
||||
x_feat = self.pre_norm_feat(x)
|
||||
|
||||
kv = self.linear1_kv(x_motion)
|
||||
q = self.linear1_q(x_feat)
|
||||
|
||||
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
|
||||
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if needed.
|
||||
q = self.q_norm(q).to(v)
|
||||
k = self.k_norm(k).to(v)
|
||||
|
||||
k = rearrange(k, "B L N H D -> (B L) N H D")
|
||||
v = rearrange(v, "B L N H D -> (B L) N H D")
|
||||
|
||||
if use_context_parallel:
|
||||
q = gather_forward(q, dim=1)
|
||||
|
||||
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
|
||||
# Compute attention.
|
||||
# Size([batches, tokens, heads, head_features])
|
||||
qkv_list = [q, k, v]
|
||||
del q,k,v
|
||||
attn = pay_attention(qkv_list)
|
||||
# attn = attention(
|
||||
# q,
|
||||
# k,
|
||||
# v,
|
||||
# max_seqlen_q=q.shape[1],
|
||||
# batch_size=q.shape[0],
|
||||
# )
|
||||
|
||||
attn = attn.reshape(*attn.shape[:2], -1)
|
||||
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
|
||||
# if use_context_parallel:
|
||||
# attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
|
||||
|
||||
output = self.linear2(attn)
|
||||
|
||||
if motion_mask is not None:
|
||||
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
|
||||
|
||||
return output
|
||||
31 models/wan/animate/model_animate.py (Normal file)
@ -0,0 +1,31 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import math
|
||||
import types
|
||||
from copy import deepcopy
|
||||
from einops import rearrange
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
|
||||
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
|
||||
pose_latents = self.pose_patch_embedding(pose_latents)
|
||||
x[:, :, 1:] += pose_latents
|
||||
|
||||
b,c,T,h,w = face_pixel_values.shape
|
||||
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
|
||||
encode_bs = 8
|
||||
face_pixel_values_tmp = []
|
||||
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
|
||||
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
|
||||
|
||||
motion_vec = torch.cat(face_pixel_values_tmp)
|
||||
|
||||
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
|
||||
motion_vec = self.face_encoder(motion_vec)
|
||||
|
||||
B, L, H, C = motion_vec.shape
|
||||
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
|
||||
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
|
||||
return x, motion_vec
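
# Notes added for clarity (not part of the original file): `pose_latents` are
# patch-embedded and added to every temporal position of `x` except the first
# (reference) latent frame. `face_pixel_values` [b, c, T, h, w] is flattened to
# per-frame crops, pushed through the LIA-style motion encoder in chunks of 8
# to bound memory, temporally compressed by FaceEncoder (two stride-2 causal
# convs), and a zero "face" token is prepended, presumably so the reference
# frame carries no facial motion.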
308 models/wan/animate/motion_encoder.py (Normal file)
@ -0,0 +1,308 @@
|
||||
# Modified from ``https://github.com/wyhsirius/LIA``
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
|
||||
def custom_qr(input_tensor):
|
||||
original_dtype = input_tensor.dtype
|
||||
if original_dtype in [torch.bfloat16, torch.float16]:
|
||||
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
|
||||
return q.to(original_dtype), r.to(original_dtype)
|
||||
return torch.linalg.qr(input_tensor)
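
# Note added for clarity (not part of the original file): torch.linalg.qr has no
# half-precision kernels, so bf16/fp16 inputs are temporarily promoted to fp32
# and the factors are cast back, e.g.
#
#     q, r = custom_qr(torch.randn(512, 20, dtype=torch.bfloat16))
#     # q: [512, 20] bfloat16, r: [20, 20] bfloat16 (reduced QR)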
|
||||
|
||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||
return F.leaky_relu(input + bias, negative_slope) * scale
|
||||
|
||||
|
||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||
_, minor, in_h, in_w = input.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
||||
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
||||
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
||||
|
||||
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
|
||||
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
|
||||
return out[:, :, ::down_y, ::down_x]
|
||||
|
||||
|
||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||
|
||||
|
||||
def make_kernel(k):
|
||||
k = torch.tensor(k, dtype=torch.float32)
|
||||
if k.ndim == 1:
|
||||
k = k[None, :] * k[:, None]
|
||||
k /= k.sum()
|
||||
return k
|
||||
|
||||
|
||||
class FusedLeakyReLU(nn.Module):
|
||||
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||
super().__init__()
|
||||
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
||||
self.negative_slope = negative_slope
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, input):
|
||||
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||
return out
|
||||
|
||||
|
||||
class Blur(nn.Module):
|
||||
def __init__(self, kernel, pad, upsample_factor=1):
|
||||
super().__init__()
|
||||
|
||||
kernel = make_kernel(kernel)
|
||||
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
def forward(self, input):
|
||||
return upfirdn2d(input, self.kernel, pad=self.pad)
|
||||
|
||||
|
||||
class ScaledLeakyReLU(nn.Module):
|
||||
def __init__(self, negative_slope=0.2):
|
||||
super().__init__()
|
||||
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, input):
|
||||
return F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||
|
||||
|
||||
class EqualConv2d(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
||||
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||
)
|
||||
|
||||
|
||||
class EqualLinear(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.activation = activation
|
||||
|
||||
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||
self.lr_mul = lr_mul
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
if self.activation:
|
||||
out = F.linear(input, self.weight * self.scale)
|
||||
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||
else:
|
||||
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
|
||||
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
|
||||
|
||||
|
||||
class ConvLayer(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
downsample=False,
|
||||
blur_kernel=[1, 3, 3, 1],
|
||||
bias=True,
|
||||
activate=True,
|
||||
):
|
||||
layers = []
|
||||
|
||||
if downsample:
|
||||
factor = 2
|
||||
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||
pad0 = (p + 1) // 2
|
||||
pad1 = p // 2
|
||||
|
||||
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||
|
||||
stride = 2
|
||||
self.padding = 0
|
||||
|
||||
else:
|
||||
stride = 1
|
||||
self.padding = kernel_size // 2
|
||||
|
||||
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
|
||||
bias=bias and not activate))
|
||||
|
||||
if activate:
|
||||
if bias:
|
||||
layers.append(FusedLeakyReLU(out_channel))
|
||||
else:
|
||||
layers.append(ScaledLeakyReLU(0.2))
|
||||
|
||||
super().__init__(*layers)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
||||
super().__init__()
|
||||
|
||||
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
||||
|
||||
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.conv1(input)
|
||||
out = self.conv2(out)
|
||||
|
||||
skip = self.skip(input)
|
||||
out = (out + skip) / math.sqrt(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class EncoderApp(nn.Module):
|
||||
def __init__(self, size, w_dim=512):
|
||||
super(EncoderApp, self).__init__()
|
||||
|
||||
channels = {
|
||||
4: 512,
|
||||
8: 512,
|
||||
16: 512,
|
||||
32: 512,
|
||||
64: 256,
|
||||
128: 128,
|
||||
256: 64,
|
||||
512: 32,
|
||||
1024: 16
|
||||
}
|
||||
|
||||
self.w_dim = w_dim
|
||||
log_size = int(math.log(size, 2))
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.convs.append(ConvLayer(3, channels[size], 1))
|
||||
|
||||
in_channel = channels[size]
|
||||
for i in range(log_size, 2, -1):
|
||||
out_channel = channels[2 ** (i - 1)]
|
||||
self.convs.append(ResBlock(in_channel, out_channel))
|
||||
in_channel = out_channel
|
||||
|
||||
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
res = []
|
||||
h = x
|
||||
for conv in self.convs:
|
||||
h = conv(h)
|
||||
res.append(h)
|
||||
|
||||
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, size, dim=512, dim_motion=20):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
# appearance network
|
||||
self.net_app = EncoderApp(size, dim)
|
||||
|
||||
# motion network
|
||||
fc = [EqualLinear(dim, dim)]
|
||||
for i in range(3):
|
||||
fc.append(EqualLinear(dim, dim))
|
||||
|
||||
fc.append(EqualLinear(dim, dim_motion))
|
||||
self.fc = nn.Sequential(*fc)
|
||||
|
||||
def enc_app(self, x):
|
||||
h_source = self.net_app(x)
|
||||
return h_source
|
||||
|
||||
def enc_motion(self, x):
|
||||
h, _ = self.net_app(x)
|
||||
h_motion = self.fc(h)
|
||||
return h_motion
|
||||
|
||||
|
||||
class Direction(nn.Module):
|
||||
def __init__(self, motion_dim):
|
||||
super(Direction, self).__init__()
|
||||
self.weight = nn.Parameter(torch.randn(512, motion_dim))
|
||||
|
||||
def forward(self, input):
|
||||
|
||||
weight = self.weight + 1e-8
|
||||
Q, R = custom_qr(weight)
|
||||
if input is None:
|
||||
return Q
|
||||
else:
|
||||
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
|
||||
out = torch.matmul(input_diag, Q.T)
|
||||
out = torch.sum(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class Synthesis(nn.Module):
|
||||
def __init__(self, motion_dim):
|
||||
super(Synthesis, self).__init__()
|
||||
self.direction = Direction(motion_dim)
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, size, style_dim=512, motion_dim=20):
|
||||
super().__init__()
|
||||
|
||||
self.enc = Encoder(size, style_dim, motion_dim)
|
||||
self.dec = Synthesis(motion_dim)
|
||||
|
||||
def get_motion(self, img):
|
||||
#motion_feat = self.enc.enc_motion(img)
|
||||
# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
|
||||
with torch.cuda.amp.autocast(dtype=torch.float32):
|
||||
motion_feat = self.enc.enc_motion(img)
|
||||
motion = self.dec.direction(motion_feat)
|
||||
return motion
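
# Illustrative sketch (added for this write-up, not part of the original file):
# end-to-end shape check for the motion encoder. `size` must match the face
# crop resolution (512 in this hypothetical example); the `motion_dim`-sized
# code is projected back through the learned orthogonal directions, so
# get_motion returns one `style_dim`-sized motion vector per input image.
def _example_get_motion():
    gen = Generator(size=512, style_dim=512, motion_dim=20)
    faces = torch.rand(2, 3, 512, 512)  # [B, 3, H, W] dummy face crops (shape check only)
    motion = gen.get_motion(faces)      # -> torch.Size([2, 512])
    return motion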
|
||||
@ -203,10 +203,7 @@ class WanAny2V:
|
||||
self.use_timestep_transform = True
|
||||
|
||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(frames)
|
||||
else:
|
||||
assert len(frames) == len(ref_images)
|
||||
ref_images = [ref_images] * len(frames)
|
||||
|
||||
if masks is None:
|
||||
latents = self.vae.encode(frames, tile_size = tile_size)
|
||||
@ -238,11 +235,7 @@ class WanAny2V:
|
||||
return cat_latents
|
||||
|
||||
def vace_encode_masks(self, masks, ref_images=None):
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(masks)
|
||||
else:
|
||||
assert len(masks) == len(ref_images)
|
||||
|
||||
ref_images = [ref_images] * len(masks)
|
||||
result_masks = []
|
||||
for mask, refs in zip(masks, ref_images):
|
||||
c, depth, height, width = mask.shape
|
||||
@ -270,79 +263,6 @@ class WanAny2V:
|
||||
result_masks.append(mask)
|
||||
return result_masks
|
||||
|
||||
def vace_latent(self, z, m):
|
||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
|
||||
image_sizes = []
|
||||
trim_video_guide = len(keep_video_guide_frames)
|
||||
def conv_tensor(t, device):
|
||||
return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
|
||||
|
||||
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
|
||||
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
|
||||
num_frames = total_frames - prepend_count
|
||||
num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
|
||||
if sub_src_mask is not None and sub_src_video is not None:
|
||||
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
|
||||
src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
|
||||
# src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
|
||||
# src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
|
||||
if prepend_count > 0:
|
||||
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
|
||||
src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
|
||||
src_video_shape = src_video[i].shape
|
||||
if src_video_shape[1] != total_frames:
|
||||
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
||||
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
||||
src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
elif sub_src_video is None:
|
||||
if prepend_count > 0:
|
||||
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
|
||||
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
|
||||
else:
|
||||
src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
image_sizes.append(image_size)
|
||||
else:
|
||||
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
if prepend_count > 0:
|
||||
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
|
||||
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
|
||||
src_video_shape = src_video[i].shape
|
||||
if src_video_shape[1] != total_frames:
|
||||
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
||||
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
for k, keep in enumerate(keep_video_guide_frames):
|
||||
if not keep:
|
||||
pos = prepend_count + k
|
||||
src_video[i][:, pos:pos+1] = 0
|
||||
src_mask[i][:, pos:pos+1] = 1
|
||||
|
||||
for k, frame in enumerate(inject_frames):
|
||||
if frame != None:
|
||||
pos = prepend_count + k
|
||||
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
||||
|
||||
|
||||
self.background_mask = None
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
image_size = image_sizes[i]
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None and not torch.is_tensor(ref_img):
|
||||
if j==0 and any_background_ref:
|
||||
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
|
||||
src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
||||
else:
|
||||
src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
|
||||
if self.background_mask != None:
|
||||
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate background mask with double controlnet since the first controlnet image ref was modified by the ref
|
||||
return src_video, src_mask, src_ref_images
|
||||
|
||||
def get_vae_latents(self, ref_images, device, tile_size= 0):
|
||||
ref_vae_latents = []
|
||||
@ -369,7 +289,9 @@ class WanAny2V:
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
input_frames= None,
|
||||
input_frames2= None,
|
||||
input_masks = None,
|
||||
input_masks2 = None,
|
||||
input_ref_images = None,
|
||||
input_ref_masks = None,
|
||||
input_faces = None,
|
||||
@ -615,21 +537,22 @@ class WanAny2V:
|
||||
pose_pixels = input_frames * input_masks
|
||||
input_masks = 1. - input_masks
|
||||
pose_pixels -= input_masks
|
||||
save_video(pose_pixels, "pose.mp4")
|
||||
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
|
||||
input_frames = input_frames * input_masks
|
||||
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should be black (-1) in background frames
|
||||
if prefix_frames_count > 0:
|
||||
input_frames[:, :prefix_frames_count] = input_video
|
||||
input_masks[:, :prefix_frames_count] = 1
|
||||
save_video(input_frames, "input_frames.mp4")
|
||||
save_video(input_masks, "input_masks.mp4", value_range=(0,1))
|
||||
# save_video(pose_pixels, "pose.mp4")
|
||||
# save_video(input_frames, "input_frames.mp4")
|
||||
# save_video(input_masks, "input_masks.mp4", value_range=(0,1))
|
||||
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
|
||||
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
|
||||
msk = torch.concat([msk_ref, msk_control], dim=1)
|
||||
clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
|
||||
lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
|
||||
image_ref = input_ref_images[0].to(self.device)
|
||||
clip_image_start = image_ref.squeeze(1)
|
||||
lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
|
||||
y = torch.concat([msk, lat_y])
|
||||
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
|
||||
lat_y = msk = msk_control = msk_ref = pose_pixels = None
|
||||
@ -701,12 +624,11 @@ class WanAny2V:
|
||||
|
||||
# Phantom
|
||||
if phantom:
|
||||
input_ref_images_neg = None
|
||||
if input_ref_images != None: # Phantom Ref images
|
||||
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
||||
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
||||
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
|
||||
trim_frames = input_ref_images.shape[1]
|
||||
lat_input_ref_images_neg = None
|
||||
if input_ref_images is not None: # Phantom Ref images
|
||||
lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
||||
lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
|
||||
ref_images_count = trim_frames = lat_input_ref_images.shape[1]
|
||||
|
||||
if ti2v:
|
||||
if input_video is None:
|
||||
@ -721,25 +643,23 @@ class WanAny2V:
|
||||
# Vace
|
||||
if vace :
|
||||
# vace context encode
|
||||
input_frames = [u.to(self.device) for u in input_frames]
|
||||
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
||||
input_masks = [u.to(self.device) for u in input_masks]
|
||||
input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
|
||||
input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
|
||||
input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
|
||||
input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
|
||||
ref_images_before = True
|
||||
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||
if self.background_mask != None:
|
||||
color_reference_frame = input_ref_images[0][0].clone()
|
||||
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
|
||||
mbg = self.vace_encode_masks(self.background_mask, None)
|
||||
if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
|
||||
color_reference_frame = input_ref_images[0].clone()
|
||||
zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
|
||||
mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
|
||||
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
|
||||
zz0[:, 0:1] = zzbg
|
||||
mm0[:, 0:1] = mmbg
|
||||
|
||||
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
|
||||
z = self.vace_latent(z0, m0)
|
||||
|
||||
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
|
||||
zz0 = mm0 = zzbg = mmbg = None
|
||||
z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
|
||||
ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0
|
||||
context_scale = context_scale if context_scale != None else [1.0] * len(z)
|
||||
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
|
||||
if overlapped_latents != None :
|
||||
@ -747,15 +667,8 @@ class WanAny2V:
|
||||
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
|
||||
if prefix_frames_count > 0:
|
||||
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
lat_h, lat_w = target_shape[-2:]
|
||||
height = self.vae_stride[1] * lat_h
|
||||
width = self.vae_stride[2] * lat_w
|
||||
|
||||
else:
|
||||
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
|
||||
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
|
||||
|
||||
if multitalk:
|
||||
if audio_proj is None:
|
||||
@ -860,7 +773,9 @@ class WanAny2V:
|
||||
apg_norm_threshold = 55
|
||||
text_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
audio_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
|
||||
input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# denoising
|
||||
trans = self.model
|
||||
@ -878,7 +793,7 @@ class WanAny2V:
|
||||
kwargs.update({"t": timestep, "current_step": start_step_no + i})
|
||||
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
||||
|
||||
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
|
||||
if denoising_strength < 1 and i <= injection_denoising_step:
|
||||
sigma = t / 1000
|
||||
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||
if inject_from_start:
|
||||
@ -912,8 +827,8 @@ class WanAny2V:
|
||||
any_guidance = guide_scale != 1
|
||||
if phantom:
|
||||
gen_args = {
|
||||
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
|
||||
[ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
|
||||
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
|
||||
[ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
|
||||
"context": [context, context_null, context_null] ,
|
||||
}
|
||||
elif fantasy:
|
||||
|
||||
@ -21,6 +21,7 @@ class family_handler():
|
||||
extra_model_def["fps"] =fps
|
||||
extra_model_def["frames_minimum"] = 17
|
||||
extra_model_def["frames_steps"] = 20
|
||||
extra_model_def["latent_size"] = 4
|
||||
extra_model_def["sliding_window"] = True
|
||||
extra_model_def["skip_layer_guidance"] = True
|
||||
extra_model_def["tea_cache"] = True
|
||||
|
||||
@ -114,7 +114,6 @@ class family_handler():
|
||||
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
|
||||
"mag_cache" : True,
|
||||
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
|
||||
"convert_image_guide_to_video" : True,
|
||||
"sample_solvers":[
|
||||
("unipc", "unipc"),
|
||||
("euler", "euler"),
|
||||
@ -175,6 +174,8 @@ class family_handler():
|
||||
extra_model_def["forced_guide_mask_inputs"] = True
|
||||
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
|
||||
extra_model_def["background_ref_outpainted"] = False
|
||||
extra_model_def["return_image_refs_tensor"] = True
|
||||
extra_model_def["guide_inpaint_color"] = 0
|
||||
|
||||
|
||||
|
||||
@ -196,15 +197,15 @@ class family_handler():
|
||||
"letters_filter": "KFI",
|
||||
}
|
||||
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
|
||||
extra_model_def["video_guide_outpainting"] = [0,1]
|
||||
extra_model_def["pad_guide_video"] = True
|
||||
extra_model_def["guide_inpaint_color"] = 127.5
|
||||
extra_model_def["forced_guide_mask_inputs"] = True
|
||||
extra_model_def["return_image_refs_tensor"] = True
|
||||
|
||||
if base_model_type in ["standin"]:
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
extra_model_def["fit_into_canvas_image_refs"] = 0
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [
|
||||
("No Reference Image", ""),
|
||||
@ -480,6 +481,7 @@ class family_handler():
|
||||
ui_defaults.update({
|
||||
"video_prompt_type": "PVBXAKI",
|
||||
"mask_expand": 20,
|
||||
"audio_prompt_type_value": "R",
|
||||
})
|
||||
|
||||
if text_oneframe_overlap(base_model_type):
|
||||
|
||||
@ -32,6 +32,14 @@ def seed_everything(seed: int):
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
|
||||
def has_video_file_extension(filename):
|
||||
extension = os.path.splitext(filename)[-1].lower()
|
||||
return extension in [".mp4"]
|
||||
|
||||
def has_image_file_extension(filename):
|
||||
extension = os.path.splitext(filename)[-1].lower()
|
||||
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
|
||||
|
||||
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
|
||||
import math
|
||||
|
||||
@ -94,7 +102,7 @@ def get_video_info(video_path):

return fps, width, height, frame_count

def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)

@ -102,6 +110,9 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
raise ValueError(f"Cannot open video: {file_name}")

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = round(cap.get(cv2.CAP_PROP_FPS))
if target_fps is not None:
frame_no = round(target_fps * frame_no /fps)

# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
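
# Note (assumption): the new target_fps argument lets a caller pass a frame index expressed in
# the generation timeline and have get_video_frame() rescale that index using the file's own FPS
# before seeking. Hypothetical call, mirroring the sparse-video usage later in this diff:
# sparse_image = get_video_frame("guide.mp4", frame_no=16, return_last_if_missing=True, target_fps=16, return_PIL=True)
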
@ -175,10 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)

def convert_tensor_to_image(t, frame_no = 0):
def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
if len(t.shape) == 4:
t = t[:, frame_no]
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
if t.shape[0]== 1:
t = t.expand(3,-1,-1)
if mask_levels:
return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
else:
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())

def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
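
# Sketch (assumption about typical usage, based only on the two helpers above): round-trip a PIL
# image through the [-1, 1] tensor convention, and render a single-channel [0, 1] mask with the
# new mask_levels flag so it is scaled by 255 instead of the image mapping.
from PIL import Image
import torch

img = Image.new("RGB", (64, 64), "red")
t = convert_image_to_tensor(img)                     # (3, H, W), values in [-1, 1]
back = convert_tensor_to_image(t)                    # PIL image again

mask = torch.ones(1, 64, 64)                         # (1, H, W), values in [0, 1]
mask_img = convert_tensor_to_image(mask, mask_levels=True)
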
@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
if rm_background:
session = new_session()

@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
for i, img in enumerate(img_list):
width, height = img.size
resized_mask = None
if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
if any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
if return_tensor:
output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
else:
output_list.append(resized_image)
output_mask_list.append(resized_mask)
return output_list, output_mask_list
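
# Sketch (assumption about intended usage): with return_tensor=True the reference images come
# back as (3, 1, H, W) tensors ready to be concatenated along the frame axis, instead of PIL
# images. Hypothetical call; "ref.png" and the sizes are placeholders, not from the commit.
refs, ref_masks = resize_and_remove_background(
    [Image.open("ref.png")], 832, 480,
    rm_background=True, any_background_ref=0,
    fit_into_canvas=1, block_size=16,
    outpainting_dims=None, background_ref_outpainted=True,
    return_tensor=True)
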
@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu

return ref_img.to(device), canvas

def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
inpaint_color = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
inpaint_color_compressed = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
src_video = src_mask = None
if cur_video_guide is not None:
src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
if cur_video_mask is not None and any_mask:
src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
if pre_video_guide is not None and not extract_guide_from_window_start:
src_video, src_mask = cur_video_guide, cur_video_mask
if pre_video_guide is not None:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
if src_video is None:
if any_guide_padding:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
if any_mask:
src_mask = torch.zeros_like(src_video[0:1])
elif src_video.shape[1] < current_video_length:
if any_guide_padding:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
if cur_video_mask is not None and any_mask:
src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)

if any_guide_padding:
if src_video is None:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
elif src_video.shape[1] < current_video_length:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
elif src_video is not None:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]

if any_mask and src_video is not None:
if src_mask is None:
src_mask = torch.ones_like(src_video[:1])
elif src_mask.shape[1] < src_video.shape[1]:
src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
else:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]
if any_mask:
src_mask = src_mask[:, :new_num_frames]
src_mask = src_mask[:, :src_video.shape[1]]

for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color
src_mask[:, pos:pos+1] = 1

for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
if src_video is not None :
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color_compressed
if any_mask: src_mask[:, pos:pos+1] = 1

for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
if any_mask: src_mask[:, pos:pos+1] = msk
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks
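
# Sketch (assumption, mirroring the wgp.py call site later in this diff): the refactored helper
# takes lists of per-guide videos/masks plus an optional pre_video_guide prefix and returns one
# padded (3, F, H, W) guide tensor and one (1, F, H, W) mask tensor per guide.
src_videos, src_masks = prepare_video_guide_and_mask(
    [video_guide_processed],            # one control video, already a tensor
    [video_mask_processed],             # matching mask, or None
    pre_video_guide,                    # frames reused from the previous sliding window, or None
    image_size, current_video_length, latent_size,
    any_mask=True, any_guide_padding=False,
    guide_inpaint_color=127.5,
    keep_video_guide_frames=keep_frames_parsed,
    inject_frames=frames_to_inject_parsed,
    outpainting_dims=None)
src_video, src_mask = src_videos[0], src_masks[0]
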
396 wgp.py
@ -24,6 +24,7 @@ from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.utils import has_video_file_extension, has_image_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.6.0"
WanGP_version = "8.61"
WanGP_version = "8.7"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type):
mode_def = get_model_def(model_type)
frames_minimum = mode_def.get("frames_minimum", 5)
frames_steps = mode_def.get("frames_steps", 4)
return frames_minimum, frames_steps
latent_size = mode_def.get("latent_size", frames_steps)
return frames_minimum, frames_steps, latent_size
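
# Note (grounded in this diff): get_model_min_frames_and_step() now returns three values, so the
# existing call sites are updated accordingly, e.g.:
#   _, _, latent_size = get_model_min_frames_and_step(model_type)                  # generate_video()
#   min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)    # UI tab
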
def get_model_fps(model_type):
mode_def = get_model_def(model_type)
@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if len(video_other_prompts) >0 :
values += [video_other_prompts]
labels += ["Other Prompts"]
if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
if len(video_outpainting) >0:
values += [video_outpainting]
labels += ["Outpainting"]
video_sample_solver = configs.get("sample_solver", "")
@ -3532,6 +3534,11 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))

def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
if isinstance(video_in, str) and has_image_file_extension(video_in):
video_in = Image.open(video_in)
if isinstance(video_in, Image.Image):
return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)

from shared.utils.utils import resample

import decord
@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) :
if not items:
return []
max_workers = 11

import concurrent.futures
start_time = time.time()
# print(f"Preprocessus:{process_type} started")
if process_type in ["prephase", "upsample"]:
if wrap_in_list :
items = [ [img] for img in items]
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
results = [None] * len(items)
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
results[idx] = future.result()
if max_workers == 1:
results = [image_processor(img) for img in items]
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
results = [None] * len(items)
for future in concurrent.futures.as_completed(futures):
idx = futures[future]
results[idx] = future.result()

if wrap_in_list:
results = [ img[0] for img in results]
@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis

return results
def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
|
||||
frame_width, frame_height = input_image.size
|
||||
|
||||
if fit_crop:
|
||||
input_image = rescale_and_crop(input_image, width, height)
|
||||
if input_mask is not None:
|
||||
input_mask = rescale_and_crop(input_mask, width, height)
|
||||
return input_image, input_mask
|
||||
|
||||
if outpainting_dims != None:
|
||||
if fit_canvas != None:
|
||||
frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
|
||||
else:
|
||||
frame_height, frame_width = height, width
|
||||
|
||||
if fit_canvas != None:
|
||||
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
|
||||
|
||||
if outpainting_dims != None:
|
||||
final_height, final_width = height, width
|
||||
height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
|
||||
|
||||
if fit_canvas != None or outpainting_dims != None:
|
||||
input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
if input_mask is not None:
|
||||
input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
if expand_scale != 0 and input_mask is not None:
|
||||
kernel_size = abs(expand_scale)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
|
||||
input_mask = np.array(input_mask)
|
||||
input_mask = op_expand(input_mask, kernel, iterations=3)
|
||||
input_mask = Image.fromarray(input_mask)
|
||||
|
||||
if outpainting_dims != None:
|
||||
inpaint_color = inpaint_color / 127.5-1
|
||||
image = convert_image_to_tensor(input_image)
|
||||
full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device)
|
||||
full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
|
||||
input_image = convert_tensor_to_image(full_frame)
|
||||
|
||||
if input_mask is not None:
|
||||
mask = convert_image_to_tensor(input_mask)
|
||||
full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
|
||||
full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
|
||||
input_mask = convert_tensor_to_image(full_frame)
|
||||
|
||||
return input_image, input_mask
|
||||
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor

def get_default_workers():
return os.cpu_count()/ 2
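
# Sketch (assumption): with the hard-coded max_workers = 11 removed, callers now size the pool
# themselves, typically through get_default_workers(), as the call sites further down this diff do:
#   max_workers = get_default_workers()
#   proc_lists = process_images_multithread(prep_prephase, list(range(num_frames)), "prephase",
#                                           wrap_in_list=False, max_workers=max_workers)
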
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
|
||||
|
||||
@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
|
||||
return (target_frame, frame, mask)
|
||||
else:
|
||||
return (target_frame, None, None)
|
||||
|
||||
proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
|
||||
max_workers = get_default_workers()
|
||||
proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers)
|
||||
proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
|
||||
for frame_idx, frame_group in enumerate(proc_lists):
|
||||
proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
|
||||
@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
|
||||
mask_video = None
|
||||
|
||||
if preproc2 != None:
|
||||
proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
|
||||
proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
|
||||
#### to be finished ...or not
|
||||
proc_list = process_images_multithread(preproc, proc_list, process_type)
|
||||
proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
|
||||
if any_mask:
|
||||
proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
|
||||
proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
|
||||
else:
|
||||
proc_list_outside = proc_mask = len(proc_list) * [None]
|
||||
|
||||
@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
|
||||
full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
|
||||
full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
|
||||
mask = full_frame
|
||||
masks.append(mask)
|
||||
masks.append(mask[:, :, 0:1].clone())
|
||||
else:
|
||||
masked_frame = processed_img
|
||||
|
||||
@ -3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
|
||||
proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
|
||||
|
||||
|
||||
if args.save_masks:
|
||||
from preprocessing.dwpose.pose import save_one_video
|
||||
saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
|
||||
save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
|
||||
if any_mask:
|
||||
saved_masks = [mask.cpu().numpy() for mask in masks ]
|
||||
save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
|
||||
# if args.save_masks:
|
||||
# from preprocessing.dwpose.pose import save_one_video
|
||||
# saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
|
||||
# save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
|
||||
# if any_mask:
|
||||
# saved_masks = [mask.cpu().numpy() for mask in masks ]
|
||||
# save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
|
||||
preproc = None
|
||||
preproc_outside = None
|
||||
gc.collect()
|
||||
@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
|
||||
if pad_frames > 0:
|
||||
masked_frames = masked_frames[0] * pad_frames + masked_frames
|
||||
if any_mask: masked_frames = masks[0] * pad_frames + masks
|
||||
masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
|
||||
masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None
|
||||
|
||||
return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
|
||||
return masked_frames, masks
|
||||
|
||||
def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
|
||||
|
||||
@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):
|
||||
frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
|
||||
def upsample_frames(frame):
|
||||
return resize_lanczos(frame, h, w).unsqueeze(1)
|
||||
sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
|
||||
sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1)
|
||||
frames_to_upsample = None
|
||||
return sample
|
||||
|
||||
@ -4609,17 +4574,13 @@ def generate_video(
|
||||
batch_size = 1
|
||||
temp_filenames_list = []
|
||||
|
||||
convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
|
||||
if convert_image_guide_to_video:
|
||||
if image_guide is not None and isinstance(image_guide, Image.Image):
|
||||
video_guide = convert_image_to_video(image_guide)
|
||||
temp_filenames_list.append(video_guide)
|
||||
image_guide = None
|
||||
if image_guide is not None and isinstance(image_guide, Image.Image):
|
||||
video_guide = image_guide
|
||||
image_guide = None
|
||||
|
||||
if image_mask is not None and isinstance(image_mask, Image.Image):
|
||||
video_mask = convert_image_to_video(image_mask)
|
||||
temp_filenames_list.append(video_mask)
|
||||
image_mask = None
|
||||
if image_mask is not None and isinstance(image_mask, Image.Image):
|
||||
video_mask = image_mask
|
||||
image_mask = None
|
||||
|
||||
if model_def.get("no_background_removal", False): remove_background_images_ref = 0
|
||||
|
||||
@ -4711,22 +4672,12 @@ def generate_video(
|
||||
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
|
||||
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
|
||||
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
|
||||
i2v = test_class_i2v(model_type)
|
||||
diffusion_forcing = "diffusion_forcing" in model_filename
|
||||
t2v = base_model_type in ["t2v"]
|
||||
ltxv = "ltxv" in model_filename
|
||||
vace = test_vace_module(base_model_type)
|
||||
hunyuan_t2v = "hunyuan_video_720" in model_filename
|
||||
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
|
||||
hunyuan_custom = "hunyuan_video_custom" in model_filename
|
||||
hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename
|
||||
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
|
||||
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
|
||||
fantasy = base_model_type in ["fantasy"]
|
||||
multitalk = model_def.get("multitalk_class", False)
|
||||
standin = model_def.get("standin_class", False)
|
||||
infinitetalk = base_model_type in ["infinitetalk"]
|
||||
animate = base_model_type in ["animate"]
|
||||
|
||||
if "B" in audio_prompt_type or "X" in audio_prompt_type:
|
||||
from models.wan.multitalk.multitalk import parse_speakers_locations
|
||||
@ -4763,9 +4714,9 @@ def generate_video(
|
||||
sliding_window_size = current_video_length
|
||||
reuse_frames = 0
|
||||
|
||||
_, latent_size = get_model_min_frames_and_step(model_type)
|
||||
if diffusion_forcing: latent_size = 4
|
||||
_, _, latent_size = get_model_min_frames_and_step(model_type)
|
||||
original_image_refs = image_refs
|
||||
image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
|
||||
# image_refs = None
|
||||
# nb_frames_positions= 0
|
||||
# Output Video Ratio Priorities:
|
||||
@ -4889,6 +4840,7 @@ def generate_video(
|
||||
initial_total_windows = 0
|
||||
discard_last_frames = sliding_window_discard_last_frames
|
||||
default_requested_frames_to_generate = current_video_length
|
||||
nb_frames_positions = 0
|
||||
if sliding_window:
|
||||
initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
|
||||
current_video_length = sliding_window_size
|
||||
@ -4907,7 +4859,7 @@ def generate_video(
|
||||
if repeat_no >= total_generation: break
|
||||
repeat_no +=1
|
||||
gen["repeat_no"] = repeat_no
|
||||
src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
|
||||
src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
|
||||
prefix_video = pre_video_frame = None
|
||||
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
|
||||
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
|
||||
@ -4963,7 +4915,6 @@ def generate_video(
|
||||
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
|
||||
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
|
||||
|
||||
src_ref_images, src_ref_masks = image_refs, None
|
||||
image_start_tensor = image_end_tensor = None
|
||||
if window_no == 1 and (video_source is not None or image_start is not None):
|
||||
if image_start is not None:
|
||||
@ -5020,7 +4971,7 @@ def generate_video(
|
||||
if len(pos) > 0:
|
||||
if pos in ["L", "l"]:
|
||||
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
|
||||
if cur_end_pos >= last_frame_no and not joker_used:
|
||||
if cur_end_pos >= last_frame_no-1 and not joker_used:
|
||||
joker_used = True
|
||||
cur_end_pos = last_frame_no -1
|
||||
project_window_no += 1
|
||||
@ -5036,141 +4987,53 @@ def generate_video(
|
||||
frames_to_inject[pos] = image_refs[i]
|
||||
|
||||
|
||||
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
|
||||
if video_guide is not None:
|
||||
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
|
||||
if len(error) > 0:
|
||||
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
||||
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
|
||||
extra_control_frames = model_def.get("extra_control_frames", 0)
|
||||
if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
|
||||
|
||||
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
|
||||
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
|
||||
guide_frames_extract_count = len(keep_frames_parsed)
|
||||
|
||||
# Extract Faces to video
|
||||
if "B" in video_prompt_type:
|
||||
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
|
||||
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
|
||||
if src_faces is not None and src_faces.shape[1] < current_video_length:
|
||||
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
|
||||
|
||||
if vace or animate:
|
||||
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
|
||||
context_scale = [ control_net_weight]
|
||||
if "V" in video_prompt_type:
|
||||
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
|
||||
preprocess_type, preprocess_type2 = "raw", None
|
||||
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
|
||||
if process_num == 0:
|
||||
preprocess_type = process_map_video_guide.get(process_letter, "raw")
|
||||
else:
|
||||
preprocess_type2 = process_map_video_guide.get(process_letter, None)
|
||||
status_info = "Extracting " + processes_names[preprocess_type]
|
||||
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
|
||||
if len(extra_process_list) == 1:
|
||||
status_info += " and " + processes_names[extra_process_list[0]]
|
||||
elif len(extra_process_list) == 2:
|
||||
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
|
||||
if preprocess_type2 is not None:
|
||||
context_scale = [ control_net_weight /2, control_net_weight2 /2]
|
||||
send_cmd("progress", [0, get_latest_status(state, status_info)])
|
||||
inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
|
||||
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
|
||||
if preprocess_type2 != None:
|
||||
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
|
||||
# Sparse Video to Video
|
||||
sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None
|
||||
|
||||
if video_guide_processed != None:
|
||||
if sample_fit_canvas != None:
|
||||
image_size = video_guide_processed.shape[-3: -1]
|
||||
sample_fit_canvas = None
|
||||
refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
|
||||
if video_guide_processed2 != None:
|
||||
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
|
||||
if video_mask_processed != None:
|
||||
refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
|
||||
|
||||
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
|
||||
|
||||
if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
|
||||
any_mask = True
|
||||
any_guide_padding = model_def.get("pad_guide_video", False)
|
||||
from shared.utils.utils import prepare_video_guide_and_mask
|
||||
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
|
||||
[video_mask_processed, video_mask_processed2],
|
||||
pre_video_guide, image_size, current_video_length, latent_size,
|
||||
any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
|
||||
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
|
||||
|
||||
src_video, src_video2 = src_videos
|
||||
src_mask, src_mask2 = src_masks
|
||||
if src_video is None:
|
||||
abort = True
|
||||
break
|
||||
if src_faces is not None:
|
||||
if src_faces.shape[1] < src_video.shape[1]:
|
||||
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
|
||||
else:
|
||||
src_faces = src_faces[:, :src_video.shape[1]]
|
||||
if args.save_masks:
|
||||
save_video( src_video, "masked_frames.mp4", fps)
|
||||
if src_video2 is not None:
|
||||
save_video( src_video2, "masked_frames2.mp4", fps)
|
||||
if any_mask:
|
||||
save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
|
||||
|
||||
elif ltxv:
|
||||
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
|
||||
status_info = "Extracting " + processes_names[preprocess_type]
|
||||
send_cmd("progress", [0, get_latest_status(state, status_info)])
|
||||
# start one frame ealier to facilitate latents merging later
|
||||
src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
|
||||
if src_video != None:
|
||||
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
|
||||
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
|
||||
refresh_preview["video_mask"] = None
|
||||
src_video = src_video.permute(3, 0, 1, 2)
|
||||
src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
|
||||
if sample_fit_canvas != None:
|
||||
image_size = src_video.shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
|
||||
elif hunyuan_custom_edit:
|
||||
if "P" in video_prompt_type:
|
||||
progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
|
||||
# Generic Video Preprocessing
|
||||
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
|
||||
preprocess_type, preprocess_type2 = "raw", None
|
||||
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
|
||||
if process_num == 0:
|
||||
preprocess_type = process_map_video_guide.get(process_letter, "raw")
|
||||
else:
|
||||
progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
|
||||
preprocess_type2 = process_map_video_guide.get(process_letter, None)
|
||||
status_info = "Extracting " + processes_names[preprocess_type]
|
||||
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
|
||||
if len(extra_process_list) == 1:
|
||||
status_info += " and " + processes_names[extra_process_list[0]]
|
||||
elif len(extra_process_list) == 2:
|
||||
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
|
||||
context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
|
||||
if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)])
|
||||
inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
|
||||
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size )
|
||||
if preprocess_type2 != None:
|
||||
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size )
|
||||
|
||||
send_cmd("progress", progress_args)
|
||||
src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
|
||||
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
|
||||
if src_mask != None:
|
||||
refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
|
||||
|
||||
elif "R" in video_prompt_type: # sparse video to video
|
||||
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
|
||||
src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size)
|
||||
refresh_preview["video_guide"] = src_image
|
||||
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
|
||||
if sample_fit_canvas != None:
|
||||
image_size = src_video.shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
|
||||
else: # video to video
|
||||
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
|
||||
if video_guide_processed is None:
|
||||
src_video = pre_video_guide
|
||||
else:
|
||||
if sample_fit_canvas != None:
|
||||
image_size = video_guide_processed.shape[-3: -1]
|
||||
sample_fit_canvas = None
|
||||
src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
|
||||
if pre_video_guide != None:
|
||||
src_video = torch.cat( [pre_video_guide, src_video], dim=1)
|
||||
elif image_guide is not None:
|
||||
new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims)
|
||||
if sample_fit_canvas is not None:
|
||||
image_size = (new_image_guide.size[1], new_image_guide.size[0])
|
||||
if video_guide_processed is not None and sample_fit_canvas is not None:
|
||||
image_size = video_guide_processed.shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
refresh_preview["image_guide"] = new_image_guide
|
||||
if new_image_mask is not None:
|
||||
refresh_preview["image_mask"] = new_image_mask
|
||||
|
||||
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
|
||||
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
|
||||
@ -5193,44 +5056,67 @@ def generate_video(
|
||||
refresh_preview["image_refs"] = image_refs
|
||||
|
||||
if len(image_refs) > nb_frames_positions:
|
||||
src_ref_images = image_refs[nb_frames_positions:]
|
||||
if remove_background_images_ref > 0:
|
||||
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
|
||||
# keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
|
||||
image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
|
||||
|
||||
src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0],
|
||||
remove_background_images_ref > 0, any_background_ref,
|
||||
fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
|
||||
fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1),
|
||||
block_size=block_size,
|
||||
outpainting_dims =outpainting_dims,
|
||||
background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
|
||||
refresh_preview["image_refs"] = image_refs
|
||||
background_ref_outpainted = model_def.get("background_ref_outpainted", True),
|
||||
return_tensor= model_def.get("return_image_refs_tensor", False) )
|
||||
|
||||
|
||||
if vace :
|
||||
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
|
||||
|
||||
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
|
||||
[video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2],
|
||||
[image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy],
|
||||
current_video_length, image_size = image_size, device ="cpu",
|
||||
keep_video_guide_frames=keep_frames_parsed,
|
||||
pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
|
||||
inject_frames= frames_to_inject_parsed,
|
||||
outpainting_dims = outpainting_dims,
|
||||
any_background_ref = any_background_ref
|
||||
)
|
||||
if len(frames_to_inject_parsed) or any_background_ref:
|
||||
new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
|
||||
if any_background_ref:
|
||||
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
|
||||
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
|
||||
if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
|
||||
any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
|
||||
any_guide_padding = model_def.get("pad_guide_video", False)
|
||||
from shared.utils.utils import prepare_video_guide_and_mask
|
||||
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
|
||||
[video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
|
||||
None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
|
||||
image_size, current_video_length, latent_size,
|
||||
any_mask, any_guide_padding, guide_inpaint_color,
|
||||
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
|
||||
video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
|
||||
if len(src_videos) == 1:
|
||||
src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None
|
||||
else:
|
||||
src_video, src_video2 = src_videos
|
||||
src_mask, src_mask2 = src_masks
|
||||
src_videos = src_masks = None
|
||||
if src_video is None:
|
||||
abort = True
|
||||
break
|
||||
if src_faces is not None:
|
||||
if src_faces.shape[1] < src_video.shape[1]:
|
||||
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
|
||||
else:
|
||||
new_image_refs += image_refs[nb_frames_positions:]
|
||||
refresh_preview["image_refs"] = new_image_refs
|
||||
new_image_refs = None
|
||||
|
||||
if sample_fit_canvas != None:
|
||||
image_size = src_video[0].shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
src_faces = src_faces[:, :src_video.shape[1]]
|
||||
if video_guide is not None or len(frames_to_inject_parsed) > 0:
|
||||
if args.save_masks:
|
||||
if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
|
||||
if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
|
||||
if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
|
||||
if video_guide is not None:
|
||||
preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
|
||||
refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
|
||||
if src_video2 is not None:
|
||||
refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
|
||||
if src_mask is not None and video_mask is not None:
|
||||
refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True)
|
||||
|
||||
if src_ref_images is not None or nb_frames_positions:
|
||||
if len(frames_to_inject_parsed):
|
||||
new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
|
||||
else:
|
||||
new_image_refs = []
|
||||
if src_ref_images is not None:
|
||||
new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ]
|
||||
refresh_preview["image_refs"] = new_image_refs
|
||||
new_image_refs = None
|
||||
|
||||
if len(refresh_preview) > 0:
|
||||
new_inputs= locals()
|
||||
@ -5339,8 +5225,6 @@ def generate_video(
|
||||
pre_video_frame = pre_video_frame,
|
||||
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
|
||||
image_refs_relative_size = image_refs_relative_size,
|
||||
image_guide= new_image_guide,
|
||||
image_mask= new_image_mask,
|
||||
outpainting_dims = outpainting_dims,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -6320,7 +6204,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
pop += ["image_refs_relative_size"]

if not vace:
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
pop += ["frames_positions", "control_net_weight", "control_net_weight2"]

if model_def.get("video_guide_outpainting", None) is None:
pop += ["video_guide_outpainting"]

if not (vace or t2v):
pop += ["min_frames_if_references"]
@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice):
choice = min(choice, len(file_list))
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)

def has_video_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]

def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
if files_to_load == None:
@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
elif recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
else:
min_frames, frames_step = get_model_min_frames_and_step(base_model_type)
min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)

current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)")

with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row:
with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row:
gr.Markdown("<B>You may transfer the existing audio tracks of a Control Video</B>")
audio_prompt_type_remux = gr.Dropdown(
choices=[
@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
with gr.Row(**default_visibility) as video_buttons_row:
video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
with gr.Row(**default_visibility) as image_buttons_row:
video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False)
video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
with gr.Group(elem_classes= "postprocess"):