Mirror of https://github.com/Wan-Video/Wan2.1.git

Commit e28c95ae91 (parent 84010bd861): Vace Contenders are in Town
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates :

### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !

So in today's release you will find two Vace wannabes that each cover only a subset of Vace features but offer some interesting advantages:

- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers, and it does that very well. You can use it either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.

- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video, ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.

### September 15 2025: WanGP v8.6 - Attack of the Clones

- The long awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
@@ -2,7 +2,7 @@
     "model": {
         "name": "Wan2.2 Lucy Edit 5B",
         "architecture": "lucy_edit",
-        "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
         ],
         "group": "wan2_2"
     },
+    "prompt": "change the clothes to red",
     "video_length": 81,
     "guidance_scale": 5,
     "flow_shift": 5,
defaults/lucy_edit_fastwan.json (new file, 16 lines)
@@ -0,0 +1,16 @@
{
    "model": {
        "name": "Wan2.2 FastWan Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
        "URLs": "lucy_edit",
        "group": "wan2_2",
        "loras": "ti2v_2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 5,
    "resolution": "1280x720"
}
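The two Lucy Edit configs above differ mainly in their sampling settings, which is where the FastWan speed-up comes from (guidance 1, five inference steps). A minimal side-by-side summary in Python; the dictionary and its name are purely illustrative, only the values come from the JSON files above:

# Values copied from defaults/lucy_edit.json and defaults/lucy_edit_fastwan.json above.
LUCY_EDIT_PRESETS = {
    "Wan2.2 Lucy Edit 5B":         {"guidance_scale": 5, "flow_shift": 5, "video_length": 81},
    "Wan2.2 FastWan Lucy Edit 5B": {"guidance_scale": 1, "flow_shift": 3, "num_inference_steps": 5,
                                    "video_length": 81, "resolution": "1280x720"},
}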
@@ -56,7 +56,7 @@ class family_handler():
        }

-       extra_model_def["lock_image_refs_ratios"] = True
+       extra_model_def["fit_into_canvas_image_refs"] = 0

        return extra_model_def

@ -142,8 +142,8 @@ class model_factory:
|
|||||||
n_prompt: str = None,
|
n_prompt: str = None,
|
||||||
sampling_steps: int = 20,
|
sampling_steps: int = 20,
|
||||||
input_ref_images = None,
|
input_ref_images = None,
|
||||||
image_guide= None,
|
input_frames= None,
|
||||||
image_mask= None,
|
input_masks= None,
|
||||||
width= 832,
|
width= 832,
|
||||||
height=480,
|
height=480,
|
||||||
embedded_guidance_scale: float = 2.5,
|
embedded_guidance_scale: float = 2.5,
|
||||||
@ -197,10 +197,12 @@ class model_factory:
|
|||||||
for new_img in input_ref_images[1:]:
|
for new_img in input_ref_images[1:]:
|
||||||
stiched = stitch_images(stiched, new_img)
|
stiched = stitch_images(stiched, new_img)
|
||||||
input_ref_images = [stiched]
|
input_ref_images = [stiched]
|
||||||
elif image_guide is not None:
|
elif input_frames is not None:
|
||||||
input_ref_images = [image_guide]
|
input_ref_images = [convert_tensor_to_image(input_frames) ]
|
||||||
else:
|
else:
|
||||||
input_ref_images = None
|
input_ref_images = None
|
||||||
|
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
|
||||||
|
|
||||||
|
|
||||||
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
|
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
|
||||||
inp, height, width = prepare_multi_ip(
|
inp, height, width = prepare_multi_ip(
|
||||||
@ -253,8 +255,8 @@ class model_factory:
|
|||||||
if image_mask is not None:
|
if image_mask is not None:
|
||||||
from shared.utils.utils import convert_image_to_tensor
|
from shared.utils.utils import convert_image_to_tensor
|
||||||
img_msk_rebuilt = inp["img_msk_rebuilt"]
|
img_msk_rebuilt = inp["img_msk_rebuilt"]
|
||||||
img= convert_image_to_tensor(image_guide)
|
img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
|
||||||
x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
|
x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
|
||||||
|
|
||||||
x = x.clamp(-1, 1)
|
x = x.clamp(-1, 1)
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
|
|||||||
@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference):
|
|||||||
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
|
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
|
||||||
else:
|
else:
|
||||||
if input_frames != None:
|
if input_frames != None:
|
||||||
target_height, target_width = input_frames.shape[-3:-1]
|
target_height, target_width = input_frames.shape[-2:]
|
||||||
elif input_video != None:
|
elif input_video != None:
|
||||||
target_height, target_width = input_video.shape[-2:]
|
target_height, target_width = input_video.shape[-2:]
|
||||||
|
|
||||||
@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference):
|
|||||||
pixel_value_bg = input_video.unsqueeze(0)
|
pixel_value_bg = input_video.unsqueeze(0)
|
||||||
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
|
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
|
||||||
if input_frames != None:
|
if input_frames != None:
|
||||||
pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
|
pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
|
||||||
pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
|
# pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
|
||||||
pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
|
# pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
|
||||||
|
pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
|
||||||
if input_video != None:
|
if input_video != None:
|
||||||
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
|
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
|
||||||
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
|
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
|
||||||
@@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference):
         if pixel_value_bg.shape[2] < frame_num:
             padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
             pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
-            pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
+            # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
+            pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)

         bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
-        pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
+        pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1
         mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
         bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
         bg_latents.mul_(self.vae.config.scaling_factor)
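The two changes in this hunk switch the mask convention fed to the VAE: padding frames are now filled with 1 instead of 255, and the in-place normalization assumes the mask is already in [0, 1] rather than [0, 255]. Both conventions map unmasked to -1 and masked to +1; a quick check of the arithmetic (illustrative only, not part of the commit):

import torch

old_mask = torch.tensor([0., 255.])   # old convention: mask pixels in [0, 255]
new_mask = torch.tensor([0., 1.])     # new convention: mask already in [0, 1]
print(old_mask.div(127.5).add(-1.))   # tensor([-1.,  1.])
print(new_mask.mul(2).add(-1.))       # tensor([-1.,  1.])  -> same range reaches the VAE encoder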
|||||||
@ -35,6 +35,8 @@ class family_handler():
|
|||||||
"selection": ["", "A", "NA", "XA", "XNA"],
|
"selection": ["", "A", "NA", "XA", "XNA"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extra_model_def["extra_control_frames"] = 1
|
||||||
|
extra_model_def["dont_cat_preguide"]= True
|
||||||
return extra_model_def
|
return extra_model_def
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -17,7 +17,7 @@ class family_handler():
|
|||||||
("Default", "default"),
|
("Default", "default"),
|
||||||
("Lightning", "lightning")],
|
("Lightning", "lightning")],
|
||||||
"guidance_max_phases" : 1,
|
"guidance_max_phases" : 1,
|
||||||
"lock_image_refs_ratios": True,
|
"fit_into_canvas_image_refs": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if base_model_type in ["qwen_image_edit_20B"]:
|
if base_model_type in ["qwen_image_edit_20B"]:
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
|||||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||||
from .pipeline_qwenimage import QwenImagePipeline
|
from .pipeline_qwenimage import QwenImagePipeline
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from shared.utils.utils import calculate_new_dimensions
|
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
|
||||||
|
|
||||||
def stitch_images(img1, img2):
|
def stitch_images(img1, img2):
|
||||||
# Resize img2 to match img1's height
|
# Resize img2 to match img1's height
|
||||||
@ -103,8 +103,8 @@ class model_factory():
|
|||||||
n_prompt = None,
|
n_prompt = None,
|
||||||
sampling_steps: int = 20,
|
sampling_steps: int = 20,
|
||||||
input_ref_images = None,
|
input_ref_images = None,
|
||||||
image_guide= None,
|
input_frames= None,
|
||||||
image_mask= None,
|
input_masks= None,
|
||||||
width= 832,
|
width= 832,
|
||||||
height=480,
|
height=480,
|
||||||
guide_scale: float = 4,
|
guide_scale: float = 4,
|
||||||
@ -179,8 +179,10 @@ class model_factory():
|
|||||||
|
|
||||||
if n_prompt is None or len(n_prompt) == 0:
|
if n_prompt is None or len(n_prompt) == 0:
|
||||||
n_prompt= "text, watermark, copyright, blurry, low resolution"
|
n_prompt= "text, watermark, copyright, blurry, low resolution"
|
||||||
if image_guide is not None:
|
|
||||||
input_ref_images = [image_guide]
|
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
|
||||||
|
if input_frames is not None:
|
||||||
|
input_ref_images = [convert_tensor_to_image(input_frames) ]
|
||||||
elif input_ref_images is not None:
|
elif input_ref_images is not None:
|
||||||
# image stiching method
|
# image stiching method
|
||||||
stiched = input_ref_images[0]
|
stiched = input_ref_images[0]
|
||||||
|
|||||||
143
models/wan/animate/animate_utils.py
Normal file
143
models/wan/animate/animate_utils.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import torch
|
||||||
|
import numbers
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
|
||||||
|
target_modules = []
|
||||||
|
for name, module in transformer.named_modules():
|
||||||
|
if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
|
||||||
|
target_modules.append(name)
|
||||||
|
|
||||||
|
transformer_lora_config = LoraConfig(
|
||||||
|
r=rank,
|
||||||
|
lora_alpha=alpha,
|
||||||
|
init_lora_weights=init_lora_weights,
|
||||||
|
target_modules=target_modules,
|
||||||
|
)
|
||||||
|
return transformer_lora_config
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TensorList(object):
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
"""
|
||||||
|
tensors: a list of torch.Tensor objects. No need to have uniform shape.
|
||||||
|
"""
|
||||||
|
assert isinstance(tensors, (list, tuple))
|
||||||
|
assert all(isinstance(u, torch.Tensor) for u in tensors)
|
||||||
|
assert len(set([u.ndim for u in tensors])) == 1
|
||||||
|
assert len(set([u.dtype for u in tensors])) == 1
|
||||||
|
assert len(set([u.device for u in tensors])) == 1
|
||||||
|
self.tensors = tensors
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
return TensorList([u.to(*args, **kwargs) for u in self.tensors])
|
||||||
|
|
||||||
|
def size(self, dim):
|
||||||
|
assert dim == 0, 'only support get the 0th size'
|
||||||
|
return len(self.tensors)
|
||||||
|
|
||||||
|
def pow(self, *args, **kwargs):
|
||||||
|
return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
|
||||||
|
|
||||||
|
def squeeze(self, dim):
|
||||||
|
assert dim != 0
|
||||||
|
if dim > 0:
|
||||||
|
dim -= 1
|
||||||
|
return TensorList([u.squeeze(dim) for u in self.tensors])
|
||||||
|
|
||||||
|
def type(self, *args, **kwargs):
|
||||||
|
return TensorList([u.type(*args, **kwargs) for u in self.tensors])
|
||||||
|
|
||||||
|
def type_as(self, other):
|
||||||
|
assert isinstance(other, (torch.Tensor, TensorList))
|
||||||
|
if isinstance(other, torch.Tensor):
|
||||||
|
return TensorList([u.type_as(other) for u in self.tensors])
|
||||||
|
else:
|
||||||
|
return TensorList([u.type(other.dtype) for u in self.tensors])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.tensors[0].dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.tensors[0].device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ndim(self):
|
||||||
|
return 1 + self.tensors[0].ndim
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.tensors[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.tensors)
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u + v)
|
||||||
|
|
||||||
|
def __radd__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v + u)
|
||||||
|
|
||||||
|
def __sub__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u - v)
|
||||||
|
|
||||||
|
def __rsub__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v - u)
|
||||||
|
|
||||||
|
def __mul__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u * v)
|
||||||
|
|
||||||
|
def __rmul__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v * u)
|
||||||
|
|
||||||
|
def __floordiv__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u // v)
|
||||||
|
|
||||||
|
def __truediv__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u / v)
|
||||||
|
|
||||||
|
def __rfloordiv__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v // u)
|
||||||
|
|
||||||
|
def __rtruediv__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v / u)
|
||||||
|
|
||||||
|
def __pow__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: u ** v)
|
||||||
|
|
||||||
|
def __rpow__(self, other):
|
||||||
|
return self._apply(other, lambda u, v: v ** u)
|
||||||
|
|
||||||
|
def __neg__(self):
|
||||||
|
return TensorList([-u for u in self.tensors])
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for tensor in self.tensors:
|
||||||
|
yield tensor
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'TensorList: \n' + repr(self.tensors)
|
||||||
|
|
||||||
|
def _apply(self, other, op):
|
||||||
|
if isinstance(other, (list, tuple, TensorList)) or (
|
||||||
|
isinstance(other, torch.Tensor) and (
|
||||||
|
other.numel() > 1 or other.ndim > 1
|
||||||
|
)
|
||||||
|
):
|
||||||
|
assert len(other) == len(self.tensors)
|
||||||
|
return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
|
||||||
|
elif isinstance(other, numbers.Number) or (
|
||||||
|
isinstance(other, torch.Tensor) and (
|
||||||
|
other.numel() == 1 and other.ndim <= 1
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return TensorList([op(u, other) for u in self.tensors])
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'unsupported operand for *: "TensorList" and "{type(other)}"'
|
||||||
|
)
|
||||||
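The TensorList helper added above mirrors a small subset of the torch.Tensor API over a list of tensors that may differ in shape (they must share ndim, dtype and device). A minimal usage sketch, not part of the commit; the import assumes the repo root is on sys.path and uses the file path shown in this diff:

import torch
from models.wan.animate.animate_utils import TensorList  # file path as added by this commit

a = torch.randn(4, 8)
b = torch.randn(4, 16)                   # different shapes are fine, same ndim/dtype/device required
tl = TensorList([a, b])

scaled = tl * 2 + 1                      # scalar arithmetic is applied to every tensor in the list
print(len(scaled), scaled.dtype, scaled.device)
print(scaled[0].shape, scaled[1].shape)  # indexing returns the underlying tensors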
382
models/wan/animate/face_blocks.py
Normal file
382
models/wan/animate/face_blocks.py
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from einops import rearrange
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from shared.attention import pay_attention
|
||||||
|
|
||||||
|
MEMORY_LAYOUT = {
|
||||||
|
"flash": (
|
||||||
|
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
||||||
|
lambda x: x,
|
||||||
|
),
|
||||||
|
"torch": (
|
||||||
|
lambda x: x.transpose(1, 2),
|
||||||
|
lambda x: x.transpose(1, 2),
|
||||||
|
),
|
||||||
|
"vanilla": (
|
||||||
|
lambda x: x.transpose(1, 2),
|
||||||
|
lambda x: x.transpose(1, 2),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
mode="torch",
|
||||||
|
drop_rate=0,
|
||||||
|
attn_mask=None,
|
||||||
|
causal=False,
|
||||||
|
max_seqlen_q=None,
|
||||||
|
batch_size=1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform QKV self attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
||||||
|
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
||||||
|
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
||||||
|
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
||||||
|
drop_rate (float): Dropout rate in attention map. (default: 0)
|
||||||
|
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
||||||
|
(default: None)
|
||||||
|
causal (bool): Whether to use causal attention. (default: False)
|
||||||
|
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
||||||
|
used to index into q.
|
||||||
|
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
||||||
|
used to index into kv.
|
||||||
|
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
||||||
|
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
||||||
|
"""
|
||||||
|
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
||||||
|
|
||||||
|
if mode == "torch":
|
||||||
|
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||||
|
attn_mask = attn_mask.to(q.dtype)
|
||||||
|
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
||||||
|
|
||||||
|
elif mode == "flash":
|
||||||
|
x = flash_attn_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
||||||
|
elif mode == "vanilla":
|
||||||
|
scale_factor = 1 / math.sqrt(q.size(-1))
|
||||||
|
|
||||||
|
b, a, s, _ = q.shape
|
||||||
|
s1 = k.size(2)
|
||||||
|
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
||||||
|
if causal:
|
||||||
|
# Only applied to self attention
|
||||||
|
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
||||||
|
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
||||||
|
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||||
|
attn_bias.to(q.dtype)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
if attn_mask.dtype == torch.bool:
|
||||||
|
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
||||||
|
else:
|
||||||
|
attn_bias += attn_mask
|
||||||
|
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
||||||
|
attn += attn_bias
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = torch.dropout(attn, p=drop_rate, train=True)
|
||||||
|
x = attn @ v
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
||||||
|
|
||||||
|
x = post_attn_layout(x)
|
||||||
|
b, s, a, d = x.shape
|
||||||
|
out = x.reshape(b, s, -1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
padding = (kernel_size - 1, 0) # T
|
||||||
|
self.time_causal_padding = padding
|
||||||
|
|
||||||
|
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FaceEncoder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
|
||||||
|
self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||||
|
self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
|
||||||
|
|
||||||
|
self.out_proj = nn.Linear(1024, hidden_dim)
|
||||||
|
self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x = rearrange(x, "b t c -> b c t")
|
||||||
|
b, c, t = x.shape
|
||||||
|
|
||||||
|
x = self.conv1_local(x)
|
||||||
|
x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
|
||||||
|
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, "b t c -> b c t")
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = rearrange(x, "b c t -> b t c")
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = rearrange(x, "b t c -> b c t")
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = rearrange(x, "b c t -> b t c")
|
||||||
|
x = self.norm3(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.out_proj(x)
|
||||||
|
x = rearrange(x, "(b n) t c -> b t n c", b=b)
|
||||||
|
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
|
||||||
|
x = torch.cat([x, padding], dim=-2)
|
||||||
|
x_local = x.clone()
|
||||||
|
|
||||||
|
return x_local
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
elementwise_affine=True,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the RMSNorm normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim (int): The dimension of the input tensor.
|
||||||
|
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
eps (float): A small value added to the denominator for numerical stability.
|
||||||
|
weight (nn.Parameter): Learnable scaling parameter.
|
||||||
|
|
||||||
|
"""
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
if elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||||
|
|
||||||
|
def _norm(self, x):
|
||||||
|
"""
|
||||||
|
Apply the RMSNorm normalization to the input tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The normalized tensor.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass through the RMSNorm layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The output tensor after applying RMSNorm.
|
||||||
|
|
||||||
|
"""
|
||||||
|
output = self._norm(x.float()).type_as(x)
|
||||||
|
if hasattr(self, "weight"):
|
||||||
|
output = output * self.weight
|
||||||
|
return output
|
||||||
|
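RMSNorm above rescales by the root mean square over the last dimension instead of centering and scaling like LayerNorm. A small numerical check of what forward computes (illustrative only; the import path follows this diff and assumes the repo environment):

import torch
from models.wan.animate.face_blocks import RMSNorm  # file path as added by this commit

x = torch.randn(2, 16, 64)
norm = RMSNorm(64)                     # elementwise_affine=True -> learnable weight, initialised to ones
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight
assert torch.allclose(norm(x), manual, atol=1e-6)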
|
||||||
|
|
||||||
|
def get_norm_layer(norm_layer):
|
||||||
|
"""
|
||||||
|
Get the normalization layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
norm_layer (str): The type of normalization layer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
norm_layer (nn.Module): The normalization layer.
|
||||||
|
"""
|
||||||
|
if norm_layer == "layer":
|
||||||
|
return nn.LayerNorm
|
||||||
|
elif norm_layer == "rms":
|
||||||
|
return RMSNorm
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
||||||
|
|
||||||
|
|
||||||
|
class FaceAdapter(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_dim: int,
|
||||||
|
heads_num: int,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
qk_norm_type: str = "rms",
|
||||||
|
num_adapter_layers: int = 1,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
):
|
||||||
|
|
||||||
|
factory_kwargs = {"dtype": dtype, "device": device}
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_dim
|
||||||
|
self.heads_num = heads_num
|
||||||
|
self.fuser_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
FaceBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.heads_num,
|
||||||
|
qk_norm=qk_norm,
|
||||||
|
qk_norm_type=qk_norm_type,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
for _ in range(num_adapter_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
motion_embed: torch.Tensor,
|
||||||
|
idx: int,
|
||||||
|
freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||||
|
freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class FaceBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
heads_num: int,
|
||||||
|
qk_norm: bool = True,
|
||||||
|
qk_norm_type: str = "rms",
|
||||||
|
qk_scale: float = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.deterministic = False
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.heads_num = heads_num
|
||||||
|
head_dim = hidden_size // heads_num
|
||||||
|
self.scale = qk_scale or head_dim**-0.5
|
||||||
|
|
||||||
|
self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
|
||||||
|
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||||
|
|
||||||
|
self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
||||||
|
|
||||||
|
qk_norm_layer = get_norm_layer(qk_norm_type)
|
||||||
|
self.q_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||||
|
)
|
||||||
|
self.k_norm = (
|
||||||
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
motion_vec: torch.Tensor,
|
||||||
|
motion_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_context_parallel=False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
B, T, N, C = motion_vec.shape
|
||||||
|
T_comp = T
|
||||||
|
|
||||||
|
x_motion = self.pre_norm_motion(motion_vec)
|
||||||
|
x_feat = self.pre_norm_feat(x)
|
||||||
|
|
||||||
|
kv = self.linear1_kv(x_motion)
|
||||||
|
q = self.linear1_q(x_feat)
|
||||||
|
|
||||||
|
k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
|
||||||
|
q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
|
||||||
|
|
||||||
|
# Apply QK-Norm if needed.
|
||||||
|
q = self.q_norm(q).to(v)
|
||||||
|
k = self.k_norm(k).to(v)
|
||||||
|
|
||||||
|
k = rearrange(k, "B L N H D -> (B L) N H D")
|
||||||
|
v = rearrange(v, "B L N H D -> (B L) N H D")
|
||||||
|
|
||||||
|
if use_context_parallel:
|
||||||
|
q = gather_forward(q, dim=1)
|
||||||
|
|
||||||
|
q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
|
||||||
|
# Compute attention.
|
||||||
|
# Size([batches, tokens, heads, head_features])
|
||||||
|
qkv_list = [q, k, v]
|
||||||
|
del q,k,v
|
||||||
|
attn = pay_attention(qkv_list)
|
||||||
|
# attn = attention(
|
||||||
|
# q,
|
||||||
|
# k,
|
||||||
|
# v,
|
||||||
|
# max_seqlen_q=q.shape[1],
|
||||||
|
# batch_size=q.shape[0],
|
||||||
|
# )
|
||||||
|
|
||||||
|
attn = attn.reshape(*attn.shape[:2], -1)
|
||||||
|
attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
|
||||||
|
# if use_context_parallel:
|
||||||
|
# attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
|
||||||
|
|
||||||
|
output = self.linear2(attn)
|
||||||
|
|
||||||
|
if motion_mask is not None:
|
||||||
|
output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
|
||||||
|
|
||||||
|
return output
|
||||||
31
models/wan/animate/model_animate.py
Normal file
31
models/wan/animate/model_animate.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import math
|
||||||
|
import types
|
||||||
|
from copy import deepcopy
|
||||||
|
from einops import rearrange
|
||||||
|
from typing import List
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
|
||||||
|
pose_latents = self.pose_patch_embedding(pose_latents)
|
||||||
|
x[:, :, 1:] += pose_latents
|
||||||
|
|
||||||
|
b,c,T,h,w = face_pixel_values.shape
|
||||||
|
face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
|
||||||
|
encode_bs = 8
|
||||||
|
face_pixel_values_tmp = []
|
||||||
|
for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
|
||||||
|
face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
|
||||||
|
|
||||||
|
motion_vec = torch.cat(face_pixel_values_tmp)
|
||||||
|
|
||||||
|
motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
|
||||||
|
motion_vec = self.face_encoder(motion_vec)
|
||||||
|
|
||||||
|
B, L, H, C = motion_vec.shape
|
||||||
|
pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
|
||||||
|
motion_vec = torch.cat([pad_face, motion_vec], dim=1)
|
||||||
|
return x, motion_vec
|
||||||
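after_patch_embedding above pushes the face frames through the motion encoder in slices of encode_bs = 8 to bound peak memory. The chunking pattern in isolation (a sketch with a stand-in encoder; none of these helper names exist in the repo):

import math
import torch

def encode_in_chunks(frames, encode_fn, encode_bs=8):
    # Same loop structure as after_patch_embedding: slice the (b*t) leading dimension.
    outs = [encode_fn(frames[i * encode_bs:(i + 1) * encode_bs])
            for i in range(math.ceil(frames.shape[0] / encode_bs))]
    return torch.cat(outs)

dummy_encoder = lambda x: x.flatten(1).mean(dim=1, keepdim=True)  # stand-in for Generator.get_motion
motion = encode_in_chunks(torch.randn(20, 3, 64, 64), dummy_encoder)
print(motion.shape)  # torch.Size([20, 1])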
308
models/wan/animate/motion_encoder.py
Normal file
308
models/wan/animate/motion_encoder.py
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
# Modified from ``https://github.com/wyhsirius/LIA``
|
||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
def custom_qr(input_tensor):
|
||||||
|
original_dtype = input_tensor.dtype
|
||||||
|
if original_dtype in [torch.bfloat16, torch.float16]:
|
||||||
|
q, r = torch.linalg.qr(input_tensor.to(torch.float32))
|
||||||
|
return q.to(original_dtype), r.to(original_dtype)
|
||||||
|
return torch.linalg.qr(input_tensor)
|
||||||
|
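custom_qr above exists because torch.linalg.qr has no half-precision kernels: the input is upcast to float32 for the decomposition and the factors are cast back to the original dtype. A usage sketch (illustrative; the import path follows this diff):

import torch
from models.wan.animate.motion_encoder import custom_qr  # file path as added by this commit

w = torch.randn(512, 20, dtype=torch.bfloat16)
q, r = custom_qr(w)                # runs QR in float32 internally, returns bfloat16 factors
print(q.dtype, q.shape, r.shape)   # torch.bfloat16 torch.Size([512, 20]) torch.Size([20, 20])
# torch.linalg.qr(w) would fail here, since QR only supports float, double and complex dtypes.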
|
||||||
|
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
||||||
|
return F.leaky_relu(input + bias, negative_slope) * scale
|
||||||
|
|
||||||
|
|
||||||
|
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
|
||||||
|
_, minor, in_h, in_w = input.shape
|
||||||
|
kernel_h, kernel_w = kernel.shape
|
||||||
|
|
||||||
|
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
||||||
|
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
||||||
|
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
||||||
|
|
||||||
|
out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||||
|
out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
||||||
|
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
|
||||||
|
|
||||||
|
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||||
|
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||||
|
out = F.conv2d(out, w)
|
||||||
|
out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||||
|
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
|
||||||
|
return out[:, :, ::down_y, ::down_x]
|
||||||
|
|
||||||
|
|
||||||
|
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
||||||
|
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
||||||
|
|
||||||
|
|
||||||
|
def make_kernel(k):
|
||||||
|
k = torch.tensor(k, dtype=torch.float32)
|
||||||
|
if k.ndim == 1:
|
||||||
|
k = k[None, :] * k[:, None]
|
||||||
|
k /= k.sum()
|
||||||
|
return k
|
||||||
|
|
||||||
|
|
||||||
|
class FusedLeakyReLU(nn.Module):
|
||||||
|
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
||||||
|
super().__init__()
|
||||||
|
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
||||||
|
self.negative_slope = negative_slope
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Blur(nn.Module):
|
||||||
|
def __init__(self, kernel, pad, upsample_factor=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
kernel = make_kernel(kernel)
|
||||||
|
|
||||||
|
if upsample_factor > 1:
|
||||||
|
kernel = kernel * (upsample_factor ** 2)
|
||||||
|
|
||||||
|
self.register_buffer('kernel', kernel)
|
||||||
|
|
||||||
|
self.pad = pad
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return upfirdn2d(input, self.kernel, pad=self.pad)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledLeakyReLU(nn.Module):
|
||||||
|
def __init__(self, negative_slope=0.2):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.negative_slope = negative_slope
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return F.leaky_relu(input, negative_slope=self.negative_slope)
|
||||||
|
|
||||||
|
|
||||||
|
class EqualConv2d(nn.Module):
|
||||||
|
def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
|
||||||
|
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_channel))
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
|
||||||
|
return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
||||||
|
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EqualLinear(nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
||||||
|
self.lr_mul = lr_mul
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
|
||||||
|
if self.activation:
|
||||||
|
out = F.linear(input, self.weight * self.scale)
|
||||||
|
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
||||||
|
else:
|
||||||
|
out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
|
||||||
|
|
||||||
|
|
||||||
|
class ConvLayer(nn.Sequential):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channel,
|
||||||
|
out_channel,
|
||||||
|
kernel_size,
|
||||||
|
downsample=False,
|
||||||
|
blur_kernel=[1, 3, 3, 1],
|
||||||
|
bias=True,
|
||||||
|
activate=True,
|
||||||
|
):
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
if downsample:
|
||||||
|
factor = 2
|
||||||
|
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
||||||
|
pad0 = (p + 1) // 2
|
||||||
|
pad1 = p // 2
|
||||||
|
|
||||||
|
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
||||||
|
|
||||||
|
stride = 2
|
||||||
|
self.padding = 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
stride = 1
|
||||||
|
self.padding = kernel_size // 2
|
||||||
|
|
||||||
|
layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
|
||||||
|
bias=bias and not activate))
|
||||||
|
|
||||||
|
if activate:
|
||||||
|
if bias:
|
||||||
|
layers.append(FusedLeakyReLU(out_channel))
|
||||||
|
else:
|
||||||
|
layers.append(ScaledLeakyReLU(0.2))
|
||||||
|
|
||||||
|
super().__init__(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
|
||||||
|
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
||||||
|
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
||||||
|
|
||||||
|
self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
out = self.conv1(input)
|
||||||
|
out = self.conv2(out)
|
||||||
|
|
||||||
|
skip = self.skip(input)
|
||||||
|
out = (out + skip) / math.sqrt(2)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderApp(nn.Module):
|
||||||
|
def __init__(self, size, w_dim=512):
|
||||||
|
super(EncoderApp, self).__init__()
|
||||||
|
|
||||||
|
channels = {
|
||||||
|
4: 512,
|
||||||
|
8: 512,
|
||||||
|
16: 512,
|
||||||
|
32: 512,
|
||||||
|
64: 256,
|
||||||
|
128: 128,
|
||||||
|
256: 64,
|
||||||
|
512: 32,
|
||||||
|
1024: 16
|
||||||
|
}
|
||||||
|
|
||||||
|
self.w_dim = w_dim
|
||||||
|
log_size = int(math.log(size, 2))
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList()
|
||||||
|
self.convs.append(ConvLayer(3, channels[size], 1))
|
||||||
|
|
||||||
|
in_channel = channels[size]
|
||||||
|
for i in range(log_size, 2, -1):
|
||||||
|
out_channel = channels[2 ** (i - 1)]
|
||||||
|
self.convs.append(ResBlock(in_channel, out_channel))
|
||||||
|
in_channel = out_channel
|
||||||
|
|
||||||
|
self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
res = []
|
||||||
|
h = x
|
||||||
|
for conv in self.convs:
|
||||||
|
h = conv(h)
|
||||||
|
res.append(h)
|
||||||
|
|
||||||
|
return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(self, size, dim=512, dim_motion=20):
|
||||||
|
super(Encoder, self).__init__()
|
||||||
|
|
||||||
|
# appearance network
|
||||||
|
self.net_app = EncoderApp(size, dim)
|
||||||
|
|
||||||
|
# motion network
|
||||||
|
fc = [EqualLinear(dim, dim)]
|
||||||
|
for i in range(3):
|
||||||
|
fc.append(EqualLinear(dim, dim))
|
||||||
|
|
||||||
|
fc.append(EqualLinear(dim, dim_motion))
|
||||||
|
self.fc = nn.Sequential(*fc)
|
||||||
|
|
||||||
|
def enc_app(self, x):
|
||||||
|
h_source = self.net_app(x)
|
||||||
|
return h_source
|
||||||
|
|
||||||
|
def enc_motion(self, x):
|
||||||
|
h, _ = self.net_app(x)
|
||||||
|
h_motion = self.fc(h)
|
||||||
|
return h_motion
|
||||||
|
|
||||||
|
|
||||||
|
class Direction(nn.Module):
|
||||||
|
def __init__(self, motion_dim):
|
||||||
|
super(Direction, self).__init__()
|
||||||
|
self.weight = nn.Parameter(torch.randn(512, motion_dim))
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
|
||||||
|
weight = self.weight + 1e-8
|
||||||
|
Q, R = custom_qr(weight)
|
||||||
|
if input is None:
|
||||||
|
return Q
|
||||||
|
else:
|
||||||
|
input_diag = torch.diag_embed(input) # alpha, diagonal matrix
|
||||||
|
out = torch.matmul(input_diag, Q.T)
|
||||||
|
out = torch.sum(out, dim=1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Synthesis(nn.Module):
|
||||||
|
def __init__(self, motion_dim):
|
||||||
|
super(Synthesis, self).__init__()
|
||||||
|
self.direction = Direction(motion_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(nn.Module):
|
||||||
|
def __init__(self, size, style_dim=512, motion_dim=20):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.enc = Encoder(size, style_dim, motion_dim)
|
||||||
|
self.dec = Synthesis(motion_dim)
|
||||||
|
|
||||||
|
def get_motion(self, img):
|
||||||
|
#motion_feat = self.enc.enc_motion(img)
|
||||||
|
# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
|
||||||
|
with torch.cuda.amp.autocast(dtype=torch.float32):
|
||||||
|
motion_feat = self.enc.enc_motion(img)
|
||||||
|
motion = self.dec.direction(motion_feat)
|
||||||
|
return motion
|
||||||
@ -203,10 +203,7 @@ class WanAny2V:
|
|||||||
self.use_timestep_transform = True
|
self.use_timestep_transform = True
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
|
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
|
||||||
if ref_images is None:
|
ref_images = [ref_images] * len(frames)
|
||||||
ref_images = [None] * len(frames)
|
|
||||||
else:
|
|
||||||
assert len(frames) == len(ref_images)
|
|
||||||
|
|
||||||
if masks is None:
|
if masks is None:
|
||||||
latents = self.vae.encode(frames, tile_size = tile_size)
|
latents = self.vae.encode(frames, tile_size = tile_size)
|
||||||
@ -238,11 +235,7 @@ class WanAny2V:
|
|||||||
return cat_latents
|
return cat_latents
|
||||||
|
|
||||||
def vace_encode_masks(self, masks, ref_images=None):
|
def vace_encode_masks(self, masks, ref_images=None):
|
||||||
if ref_images is None:
|
ref_images = [ref_images] * len(masks)
|
||||||
ref_images = [None] * len(masks)
|
|
||||||
else:
|
|
||||||
assert len(masks) == len(ref_images)
|
|
||||||
|
|
||||||
result_masks = []
|
result_masks = []
|
||||||
for mask, refs in zip(masks, ref_images):
|
for mask, refs in zip(masks, ref_images):
|
||||||
c, depth, height, width = mask.shape
|
c, depth, height, width = mask.shape
|
||||||
@ -270,79 +263,6 @@ class WanAny2V:
|
|||||||
result_masks.append(mask)
|
result_masks.append(mask)
|
||||||
return result_masks
|
return result_masks
|
||||||
|
|
||||||
def vace_latent(self, z, m):
|
|
||||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
|
|
||||||
image_sizes = []
|
|
||||||
trim_video_guide = len(keep_video_guide_frames)
|
|
||||||
def conv_tensor(t, device):
|
|
||||||
return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
|
|
||||||
|
|
||||||
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
|
|
||||||
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
|
|
||||||
num_frames = total_frames - prepend_count
|
|
||||||
num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
|
|
||||||
if sub_src_mask is not None and sub_src_video is not None:
|
|
||||||
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
|
|
||||||
src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
|
|
||||||
# src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
|
|
||||||
# src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
|
|
||||||
if prepend_count > 0:
|
|
||||||
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
|
|
||||||
src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
|
|
||||||
src_video_shape = src_video[i].shape
|
|
||||||
if src_video_shape[1] != total_frames:
|
|
||||||
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
|
||||||
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
|
||||||
src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
|
|
||||||
image_sizes.append(src_video[i].shape[2:])
|
|
||||||
elif sub_src_video is None:
|
|
||||||
if prepend_count > 0:
|
|
||||||
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
|
|
||||||
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
|
|
||||||
else:
|
|
||||||
src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
|
|
||||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
|
||||||
image_sizes.append(image_size)
|
|
||||||
else:
|
|
||||||
src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
|
|
||||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
|
||||||
if prepend_count > 0:
|
|
||||||
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
|
|
||||||
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
|
|
||||||
src_video_shape = src_video[i].shape
|
|
||||||
if src_video_shape[1] != total_frames:
|
|
||||||
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
|
||||||
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
|
||||||
image_sizes.append(src_video[i].shape[2:])
|
|
||||||
for k, keep in enumerate(keep_video_guide_frames):
|
|
||||||
if not keep:
|
|
||||||
pos = prepend_count + k
|
|
||||||
src_video[i][:, pos:pos+1] = 0
|
|
||||||
src_mask[i][:, pos:pos+1] = 1
|
|
||||||
|
|
||||||
for k, frame in enumerate(inject_frames):
|
|
||||||
if frame != None:
|
|
||||||
pos = prepend_count + k
|
|
||||||
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
|
||||||
|
|
||||||
|
|
||||||
self.background_mask = None
|
|
||||||
for i, ref_images in enumerate(src_ref_images):
|
|
||||||
if ref_images is not None:
|
|
||||||
image_size = image_sizes[i]
|
|
||||||
for j, ref_img in enumerate(ref_images):
|
|
||||||
if ref_img is not None and not torch.is_tensor(ref_img):
|
|
||||||
if j==0 and any_background_ref:
|
|
||||||
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
|
|
||||||
src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
|
||||||
else:
|
|
||||||
src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
|
|
||||||
if self.background_mask != None:
|
|
||||||
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
|
|
||||||
return src_video, src_mask, src_ref_images
|
|
||||||
|
|
||||||
def get_vae_latents(self, ref_images, device, tile_size= 0):
|
def get_vae_latents(self, ref_images, device, tile_size= 0):
|
||||||
ref_vae_latents = []
|
ref_vae_latents = []
|
||||||
@ -369,7 +289,9 @@ class WanAny2V:
|
|||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
input_frames= None,
|
input_frames= None,
|
||||||
|
input_frames2= None,
|
||||||
input_masks = None,
|
input_masks = None,
|
||||||
|
input_masks2 = None,
|
||||||
input_ref_images = None,
|
input_ref_images = None,
|
||||||
input_ref_masks = None,
|
input_ref_masks = None,
|
||||||
input_faces = None,
|
input_faces = None,
|
||||||
@ -615,21 +537,22 @@ class WanAny2V:
|
|||||||
pose_pixels = input_frames * input_masks
|
pose_pixels = input_frames * input_masks
|
||||||
input_masks = 1. - input_masks
|
input_masks = 1. - input_masks
|
||||||
pose_pixels -= input_masks
|
pose_pixels -= input_masks
|
||||||
save_video(pose_pixels, "pose.mp4")
|
|
||||||
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
|
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
|
||||||
input_frames = input_frames * input_masks
|
input_frames = input_frames * input_masks
|
||||||
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
|
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
|
||||||
if prefix_frames_count > 0:
|
if prefix_frames_count > 0:
|
||||||
input_frames[:, :prefix_frames_count] = input_video
|
input_frames[:, :prefix_frames_count] = input_video
|
||||||
input_masks[:, :prefix_frames_count] = 1
|
input_masks[:, :prefix_frames_count] = 1
|
||||||
save_video(input_frames, "input_frames.mp4")
|
# save_video(pose_pixels, "pose.mp4")
|
||||||
save_video(input_masks, "input_masks.mp4", value_range=(0,1))
|
# save_video(input_frames, "input_frames.mp4")
|
||||||
|
# save_video(input_masks, "input_masks.mp4", value_range=(0,1))
|
||||||
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||||
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
|
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
|
||||||
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
|
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
|
||||||
msk = torch.concat([msk_ref, msk_control], dim=1)
|
msk = torch.concat([msk_ref, msk_control], dim=1)
|
||||||
clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
|
image_ref = input_ref_images[0].to(self.device)
|
||||||
lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
|
clip_image_start = image_ref.squeeze(1)
|
||||||
|
lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
|
||||||
y = torch.concat([msk, lat_y])
|
y = torch.concat([msk, lat_y])
|
||||||
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
|
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
|
||||||
lat_y = msk = msk_control = msk_ref = pose_pixels = None
|
lat_y = msk = msk_control = msk_ref = pose_pixels = None
|
||||||
@ -701,12 +624,11 @@ class WanAny2V:

# Phantom
if phantom:
- input_ref_images_neg = None
+ lat_input_ref_images_neg = None
- if input_ref_images != None: # Phantom Ref images
+ if input_ref_images is not None: # Phantom Ref images
- input_ref_images = self.get_vae_latents(input_ref_images, self.device)
+ lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
- input_ref_images_neg = torch.zeros_like(input_ref_images)
+ lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
- ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
+ ref_images_count = trim_frames = lat_input_ref_images.shape[1]
- trim_frames = input_ref_images.shape[1]

if ti2v:
if input_video is None:
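In the Phantom branch above, the reference latents are appended to the generated latents for the two conditioned passes, while the unconditional pass gets zeroed reference latents. A minimal, self-contained sketch of that batching idea (no batch dimension, illustrative names, not the repository's API):

import torch

def build_phantom_inputs(latents, lat_refs, ref_count):
    # latents: (C, F, H, W) where the last ref_count frames hold reference latents.
    # The positive branches keep the real references, the negative branch gets zeros,
    # so guidance can contrast "with refs" against "without refs".
    lat_refs_neg = torch.zeros_like(lat_refs)
    body = latents[:, :-ref_count]                  # generated frames only
    x_pos = torch.cat([body, lat_refs], dim=1)
    x_neg = torch.cat([body, lat_refs_neg], dim=1)
    return [x_pos, x_pos, x_neg]

x = torch.randn(16, 21 + 1, 60, 104)   # 21 video latents plus 1 reference latent
batches = build_phantom_inputs(x, x[:, -1:], 1)
print([b.shape for b in batches])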
@ -721,25 +643,23 @@ class WanAny2V:
# Vace
if vace :
# vace context encode
- input_frames = [u.to(self.device) for u in input_frames]
+ input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
- input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
+ input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
- input_masks = [u.to(self.device) for u in input_masks]
+ input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
+ input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
ref_images_before = True
- if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
- if self.background_mask != None:
+ if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
- color_reference_frame = input_ref_images[0][0].clone()
+ color_reference_frame = input_ref_images[0].clone()
- zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
+ zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
- mbg = self.vace_encode_masks(self.background_mask, None)
+ mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
zz0[:, 0:1] = zzbg
mm0[:, 0:1] = mmbg
+ zz0 = mm0 = zzbg = mmbg = None
- self.background_mask = zz0 = mm0 = zzbg = mmbg = None
+ z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
- z = self.vace_latent(z0, m0)
+ ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0

- ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None :
@ -747,15 +667,8 @@ class WanAny2V:
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
if prefix_frames_count > 0:
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
+ lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
- target_shape = list(z0[0].shape)
+ target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
- target_shape[0] = int(target_shape[0] / 2)
- lat_h, lat_w = target_shape[-2:]
- height = self.vae_stride[1] * lat_h
- width = self.vae_stride[2] * lat_w

- else:
- target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])

if multitalk:
if audio_proj is None:
@ -860,7 +773,9 @@ class WanAny2V:
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
+ input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
+ gc.collect()
+ torch.cuda.empty_cache()

# denoising
trans = self.model
@ -878,7 +793,7 @@ class WanAny2V:
kwargs.update({"t": timestep, "current_step": start_step_no + i})
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None

- if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
+ if denoising_strength < 1 and i <= injection_denoising_step:
sigma = t / 1000
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
@ -912,8 +827,8 @@ class WanAny2V:
any_guidance = guide_scale != 1
if phantom:
gen_args = {
- "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
+ "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
- [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
+ [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] ,
}
elif fantasy:
@ -21,6 +21,7 @@ class family_handler():
extra_model_def["fps"] =fps
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 20
+ extra_model_def["latent_size"] = 4
extra_model_def["sliding_window"] = True
extra_model_def["skip_layer_guidance"] = True
extra_model_def["tea_cache"] = True
@ -114,7 +114,6 @@ class family_handler():
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
- "convert_image_guide_to_video" : True,
"sample_solvers":[
("unipc", "unipc"),
("euler", "euler"),
@ -175,6 +174,8 @@ class family_handler():
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
+ extra_model_def["return_image_refs_tensor"] = True
+ extra_model_def["guide_inpaint_color"] = 0
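The extra_model_def entries above are plain dictionary flags that downstream code reads back with .get(). A small hedged sketch of a consumer of the newly added keys (the dictionary literal and the defaults are assumptions for illustration):

# Hypothetical consumer of the per-model definition dictionary shown above.
extra_model_def = {
    "latent_size": 4,                 # temporal latent stride used for frame alignment
    "return_image_refs_tensor": True, # reference images come back as tensors, not PIL images
    "guide_inpaint_color": 0,         # fill value for masked-out guide frames
}

latent_size = extra_model_def.get("latent_size", 4)
refs_as_tensor = extra_model_def.get("return_image_refs_tensor", False)
inpaint_color = extra_model_def.get("guide_inpaint_color", 127.5)
print(latent_size, refs_as_tensor, inpaint_color)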
@ -196,15 +197,15 @@ class family_handler():
"letters_filter": "KFI",
}

- extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
+ extra_model_def["return_image_refs_tensor"] = True

if base_model_type in ["standin"]:
- extra_model_def["lock_image_refs_ratios"] = True
+ extra_model_def["fit_into_canvas_image_refs"] = 0
extra_model_def["image_ref_choices"] = {
"choices": [
("No Reference Image", ""),
@ -480,6 +481,7 @@ class family_handler():
ui_defaults.update({
"video_prompt_type": "PVBXAKI",
"mask_expand": 20,
+ "audio_prompt_type_value": "R",
})

if text_oneframe_overlap(base_model_type):
@ -32,6 +32,14 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)

+ def has_video_file_extension(filename):
+ extension = os.path.splitext(filename)[-1].lower()
+ return extension in [".mp4"]
+
+ def has_image_file_extension(filename):
+ extension = os.path.splitext(filename)[-1].lower()
+ return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math

@ -94,7 +102,7 @@ def get_video_info(video_path):

return fps, width, height, frame_count

- def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
+ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)
@ -102,6 +110,9 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
raise ValueError(f"Cannot open video: {file_name}")

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = round(cap.get(cv2.CAP_PROP_FPS))
+ if target_fps is not None:
+ frame_no = round(target_fps * frame_no /fps)

# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
@ -175,10 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)

- def convert_tensor_to_image(t, frame_no = 0):
+ def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
if len(t.shape) == 4:
t = t[:, frame_no]
- return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
+ if t.shape[0]== 1:
+ t = t.expand(3,-1,-1)
+ if mask_levels:
+ return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
+ else:
+ return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())

def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
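The new mask_levels flag reflects the two value ranges used throughout the code base: image tensors live in [-1, 1] while mask tensors live in [0, 1]. A standalone sketch of the same conversion (a local reimplementation, not an import from the repository):

import torch
from PIL import Image

def tensor_to_pil(t: torch.Tensor, mask_levels: bool = False) -> Image.Image:
    # t is (C, H, W); single-channel masks are broadcast to 3 channels
    if t.shape[0] == 1:
        t = t.expand(3, -1, -1)
    if mask_levels:
        arr = t.clone().mul_(255)             # mask range [0, 1] -> [0, 255]
    else:
        arr = t.clone().add_(1.).mul_(127.5)  # image range [-1, 1] -> [0, 255]
    return Image.fromarray(arr.permute(1, 2, 0).to(torch.uint8).cpu().numpy())

img = tensor_to_pil(torch.rand(3, 64, 64) * 2 - 1)            # image-style tensor
msk = tensor_to_pil(torch.rand(1, 64, 64), mask_levels=True)  # mask-style tensor
print(img.size, msk.size)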
@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width

- def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
+ def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
if rm_background:
session = new_session()

@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
for i, img in enumerate(img_list):
width, height = img.size
resized_mask = None
- if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
+ if any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
- output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
+ if return_tensor:
+ output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
+ else:
+ output_list.append(resized_image)
output_mask_list.append(resized_mask)
return output_list, output_mask_list
@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu

return ref_img.to(device), canvas

- def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
+ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
- inpaint_color = guide_inpaint_color/127.5 - 1
+ inpaint_color_compressed = guide_inpaint_color/127.5 - 1
- prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
+ prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
- src_video = src_mask = None
+ src_video, src_mask = cur_video_guide, cur_video_mask
- if cur_video_guide is not None:
+ if pre_video_guide is not None:
- src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
- if cur_video_mask is not None and any_mask:
- src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
- if pre_video_guide is not None and not extract_guide_from_window_start:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
- if src_video is None:
if any_guide_padding:
- src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
+ if src_video is None:
- if any_mask:
+ src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
- src_mask = torch.zeros_like(src_video[0:1])
+ elif src_video.shape[1] < current_video_length:
- elif src_video.shape[1] < current_video_length:
+ src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
- if any_guide_padding:
+ elif src_video is not None:
- src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
- if cur_video_mask is not None and any_mask:
+ src_video = src_video[:, :new_num_frames]
- src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ if any_mask and src_video is not None:
+ if src_mask is None:
+ src_mask = torch.ones_like(src_video[:1])
+ elif src_mask.shape[1] < src_video.shape[1]:
+ src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
else:
- new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
+ src_mask = src_mask[:, :src_video.shape[1]]
- src_video = src_video[:, :new_num_frames]
- if any_mask:
- src_mask = src_mask[:, :new_num_frames]

- for k, keep in enumerate(keep_video_guide_frames):
+ if src_video is not None :
- if not keep:
+ for k, keep in enumerate(keep_video_guide_frames):
- pos = prepend_count + k
+ if not keep:
- src_video[:, pos:pos+1] = inpaint_color
+ pos = prepend_count + k
- src_mask[:, pos:pos+1] = 1
+ src_video[:, pos:pos+1] = inpaint_color_compressed
+ if any_mask: src_mask[:, pos:pos+1] = 1
- for k, frame in enumerate(inject_frames):
- if frame != None:
- pos = prepend_count + k
- src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)

+ for k, frame in enumerate(inject_frames):
+ if frame != None:
+ pos = prepend_count + k
+ src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
+ if any_mask: src_mask[:, pos:pos+1] = msk
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks
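The rewritten helper above assumes the guides already arrive as normalized (3, F, H, W) tensors and handles prepending, padding and latent alignment. A compact, self-contained sketch of the two adjustments it applies, padding a short guide with the inpaint color and trimming to a latent_size * k + 1 frame count (function name and defaults are illustrative):

import torch

def pad_and_align_guide(guide, target_frames, latent_size=4, fill=-1.0):
    # guide: (3, F, H, W) in [-1, 1]; pad with the inpaint color up to target_frames,
    # then trim to the nearest frame count of the form latent_size * k + 1
    c, f, h, w = guide.shape
    if f < target_frames:
        pad = torch.full((c, target_frames - f, h, w), fill, dtype=guide.dtype)
        guide = torch.cat([guide, pad], dim=1)
    aligned = (guide.shape[1] - 1) // latent_size * latent_size + 1
    return guide[:, :aligned]

g = torch.zeros(3, 50, 32, 32)
print(pad_and_align_guide(g, 81).shape)  # torch.Size([3, 81, 32, 32])
print(pad_and_align_guide(g, 0).shape)   # torch.Size([3, 49, 32, 32])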
wgp.py (396 changes)
@ -24,6 +24,7 @@ from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
+ from shared.utils.utils import has_video_file_extension, has_image_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.6.0"
- WanGP_version = "8.61"
+ WanGP_version = "8.7"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type):
mode_def = get_model_def(model_type)
frames_minimum = mode_def.get("frames_minimum", 5)
frames_steps = mode_def.get("frames_steps", 4)
- return frames_minimum, frames_steps
+ latent_size = mode_def.get("latent_size", frames_steps)
+ return frames_minimum, frames_steps, latent_size

def get_model_fps(model_type):
mode_def = get_model_def(model_type)
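get_model_min_frames_and_step now returns a third value, with latent_size falling back to frames_steps when a model definition does not set it. A self-contained sketch of that contract, using a stand-in that takes the definition dictionary directly instead of a model_type:

# Minimal sketch of the three-value return, with a stand-in definition dict.
def min_frames_and_step(mode_def: dict):
    frames_minimum = mode_def.get("frames_minimum", 5)
    frames_steps = mode_def.get("frames_steps", 4)
    latent_size = mode_def.get("latent_size", frames_steps)
    return frames_minimum, frames_steps, latent_size

print(min_frames_and_step({"frames_minimum": 17, "frames_steps": 20, "latent_size": 4}))  # (17, 20, 4)
print(min_frames_and_step({}))  # (5, 4, 4)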
@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if len(video_other_prompts) >0 :
values += [video_other_prompts]
labels += ["Other Prompts"]
- if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
+ if len(video_outpainting) >0:
values += [video_outpainting]
labels += ["Outpainting"]
video_sample_solver = configs.get("sample_solver", "")
@ -3532,6 +3534,11 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))

def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
+ if isinstance(video_in, str) and has_image_file_extension(video_in):
+ video_in = Image.open(video_in)
+ if isinstance(video_in, Image.Image):
+ return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)

from shared.utils.utils import resample

import decord
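With this change get_resampled_video also accepts a still image (path or PIL object) and returns it as a single-frame batch instead of decoding a video. A standalone sketch of that fallback path (the helper and extension list are local to the sketch):

import numpy as np
import torch
from PIL import Image

def frames_from_input(video_in):
    # Accept a PIL image or an image path and return a (1, H, W, C) uint8 tensor,
    # matching the shape a one-frame video would produce.
    if isinstance(video_in, str) and video_in.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
        video_in = Image.open(video_in)
    if isinstance(video_in, Image.Image):
        return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)
    raise NotImplementedError("video decoding is out of scope for this sketch")

img = Image.new("RGB", (64, 48), (255, 0, 0))
print(frames_from_input(img).shape)  # torch.Size([1, 48, 64, 3])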
@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) :
if not items:
return []
- max_workers = 11
import concurrent.futures
start_time = time.time()
# print(f"Preprocessus:{process_type} started")
if process_type in ["prephase", "upsample"]:
if wrap_in_list :
items = [ [img] for img in items]
- with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ if max_workers == 1:
- futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
+ results = [image_processor(img) for img in items]
- results = [None] * len(items)
+ else:
- for future in concurrent.futures.as_completed(futures):
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
- idx = futures[future]
+ futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
- results[idx] = future.result()
+ results = [None] * len(items)
+ for future in concurrent.futures.as_completed(futures):
+ idx = futures[future]
+ results[idx] = future.result()

if wrap_in_list:
results = [ img[0] for img in results]
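The updated helper drops the hard-coded max_workers = 11 and falls back to a plain sequential loop when max_workers == 1, while the threaded path keeps results in input order through the futures-to-index map. A minimal standalone sketch of that dispatch pattern:

import concurrent.futures

def map_preserving_order(fn, items, max_workers=4):
    # Sequential fallback avoids thread overhead when only one worker is requested;
    # the threaded branch records each future's input index so results stay ordered.
    if max_workers == 1:
        return [fn(x) for x in items]
    results = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(fn, x): i for i, x in enumerate(items)}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
    return results

print(map_preserving_order(lambda x: x * x, [1, 2, 3, 4]))              # [1, 4, 9, 16]
print(map_preserving_order(lambda x: x + 1, [1, 2, 3], max_workers=1))  # [2, 3, 4]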
@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis

return results

- def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
- frame_width, frame_height = input_image.size

- if fit_crop:
- input_image = rescale_and_crop(input_image, width, height)
- if input_mask is not None:
- input_mask = rescale_and_crop(input_mask, width, height)
- return input_image, input_mask

- if outpainting_dims != None:
- if fit_canvas != None:
- frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
- else:
- frame_height, frame_width = height, width

- if fit_canvas != None:
- height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)

- if outpainting_dims != None:
- final_height, final_width = height, width
- height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)

- if fit_canvas != None or outpainting_dims != None:
- input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
- if input_mask is not None:
- input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)

- if expand_scale != 0 and input_mask is not None:
- kernel_size = abs(expand_scale)
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
- op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
- input_mask = np.array(input_mask)
- input_mask = op_expand(input_mask, kernel, iterations=3)
- input_mask = Image.fromarray(input_mask)

- if outpainting_dims != None:
- inpaint_color = inpaint_color / 127.5-1
- image = convert_image_to_tensor(input_image)
- full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device)
- full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
- input_image = convert_tensor_to_image(full_frame)

- if input_mask is not None:
- mask = convert_image_to_tensor(input_mask)
- full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
- full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
- input_mask = convert_tensor_to_image(full_frame)

- return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor

+ def get_default_workers():
+ return os.cpu_count()/ 2

def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):

@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
return (target_frame, frame, mask)
else:
return (target_frame, None, None)
+ max_workers = get_default_workers()
- proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
+ proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers)
proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
for frame_idx, frame_group in enumerate(proc_lists):
proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
mask_video = None

if preproc2 != None:
- proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
+ proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
#### to be finished ...or not
- proc_list = process_images_multithread(preproc, proc_list, process_type)
+ proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
if any_mask:
- proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
+ proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
else:
proc_list_outside = proc_mask = len(proc_list) * [None]

@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
mask = full_frame
- masks.append(mask)
+ masks.append(mask[:, :, 0:1].clone())
else:
masked_frame = processed_img
@ -3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None

- if args.save_masks:
+ # if args.save_masks:
- from preprocessing.dwpose.pose import save_one_video
+ # from preprocessing.dwpose.pose import save_one_video
- saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
+ # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
- save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
+ # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
- if any_mask:
+ # if any_mask:
- saved_masks = [mask.cpu().numpy() for mask in masks ]
+ # saved_masks = [mask.cpu().numpy() for mask in masks ]
- save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
+ # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
preproc = None
preproc_outside = None
gc.collect()
@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
if pad_frames > 0:
masked_frames = masked_frames[0] * pad_frames + masked_frames
if any_mask: masked_frames = masks[0] * pad_frames + masks
+ masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
+ masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None

- return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
+ return masked_frames, masks

def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
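preprocess_video_with_mask now returns normalized tensors directly: frames are stacked from (H, W, C) uint8 into a (C, F, H, W) float tensor in [-1, 1], masks into [0, 1]. A standalone sketch of that conversion:

import torch

def stack_frames(frames, is_mask=False):
    # frames: list of (H, W, C) uint8 tensors -> (C, F, H, W) float tensor
    stacked = torch.stack(frames).permute(-1, 0, 1, 2).float()
    return stacked.div_(255) if is_mask else stacked.div_(127.5).sub_(1.)

frames = [torch.zeros(48, 64, 3, dtype=torch.uint8) for _ in range(5)]
masks = [torch.full((48, 64, 1), 255, dtype=torch.uint8) for _ in range(5)]
print(stack_frames(frames).shape, stack_frames(frames).min().item())               # (3, 5, 48, 64) -1.0
print(stack_frames(masks, is_mask=True).shape, stack_frames(masks, is_mask=True).max().item())  # (1, 5, 48, 64) 1.0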
@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):
frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
def upsample_frames(frame):
return resize_lanczos(frame, h, w).unsqueeze(1)
- sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
+ sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1)
frames_to_upsample = None
return sample

@ -4609,17 +4574,13 @@ def generate_video(
batch_size = 1
temp_filenames_list = []

- convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
+ if image_guide is not None and isinstance(image_guide, Image.Image):
- if convert_image_guide_to_video:
+ video_guide = image_guide
- if image_guide is not None and isinstance(image_guide, Image.Image):
+ image_guide = None
- video_guide = convert_image_to_video(image_guide)
- temp_filenames_list.append(video_guide)
- image_guide = None

if image_mask is not None and isinstance(image_mask, Image.Image):
- video_mask = convert_image_to_video(image_mask)
+ video_mask = image_mask
- temp_filenames_list.append(video_mask)
+ image_mask = None
- image_mask = None

if model_def.get("no_background_removal", False): remove_background_images_ref = 0
@ -4711,22 +4672,12 @@ def generate_video(
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
- i2v = test_class_i2v(model_type)
- diffusion_forcing = "diffusion_forcing" in model_filename
- t2v = base_model_type in ["t2v"]
- ltxv = "ltxv" in model_filename
- vace = test_vace_module(base_model_type)
- hunyuan_t2v = "hunyuan_video_720" in model_filename
- hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_custom = "hunyuan_video_custom" in model_filename
hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
- standin = model_def.get("standin_class", False)
- infinitetalk = base_model_type in ["infinitetalk"]
- animate = base_model_type in ["animate"]

if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@ -4763,9 +4714,9 @@ def generate_video(
sliding_window_size = current_video_length
reuse_frames = 0

- _, latent_size = get_model_min_frames_and_step(model_type)
+ _, _, latent_size = get_model_min_frames_and_step(model_type)
- if diffusion_forcing: latent_size = 4
original_image_refs = image_refs
+ image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
# image_refs = None
# nb_frames_positions= 0
# Output Video Ratio Priorities:
@ -4889,6 +4840,7 @@ def generate_video(
initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
+ nb_frames_positions = 0
if sliding_window:
initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
current_video_length = sliding_window_size
@ -4907,7 +4859,7 @@ def generate_video(
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
+ src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@ -4963,7 +4915,6 @@ def generate_video(
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}

- src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
@ -5020,7 +4971,7 @@ def generate_video(
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
- if cur_end_pos >= last_frame_no and not joker_used:
+ if cur_end_pos >= last_frame_no-1 and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
@ -5036,141 +4987,53 @@ def generate_video(
|
|||||||
frames_to_inject[pos] = image_refs[i]
|
frames_to_inject[pos] = image_refs[i]
|
||||||
|
|
||||||
|
|
||||||
|
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
|
||||||
if video_guide is not None:
|
if video_guide is not None:
|
||||||
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
|
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
|
||||||
if len(error) > 0:
|
if len(error) > 0:
|
||||||
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
||||||
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
|
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
|
||||||
|
extra_control_frames = model_def.get("extra_control_frames", 0)
|
||||||
|
if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
|
||||||
|
|
||||||
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
|
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
|
||||||
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
|
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
|
||||||
guide_frames_extract_count = len(keep_frames_parsed)
|
guide_frames_extract_count = len(keep_frames_parsed)
|
||||||
|
|
||||||
|
# Extract Faces to video
|
||||||
if "B" in video_prompt_type:
|
if "B" in video_prompt_type:
|
||||||
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
|
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
|
||||||
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
|
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
|
||||||
if src_faces is not None and src_faces.shape[1] < current_video_length:
|
if src_faces is not None and src_faces.shape[1] < current_video_length:
|
||||||
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
|
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
|
||||||
|
|
||||||
if vace or animate:
|
# Sparse Video to Video
|
||||||
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
|
sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None
|
||||||
context_scale = [ control_net_weight]
|
|
||||||
if "V" in video_prompt_type:
|
|
||||||
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
|
|
||||||
preprocess_type, preprocess_type2 = "raw", None
|
|
||||||
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
|
|
||||||
if process_num == 0:
|
|
||||||
preprocess_type = process_map_video_guide.get(process_letter, "raw")
|
|
||||||
else:
|
|
||||||
preprocess_type2 = process_map_video_guide.get(process_letter, None)
|
|
||||||
status_info = "Extracting " + processes_names[preprocess_type]
|
|
||||||
extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
|
|
||||||
if len(extra_process_list) == 1:
|
|
||||||
status_info += " and " + processes_names[extra_process_list[0]]
|
|
||||||
elif len(extra_process_list) == 2:
|
|
||||||
status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
|
|
||||||
if preprocess_type2 is not None:
|
|
||||||
context_scale = [ control_net_weight /2, control_net_weight2 /2]
|
|
||||||
send_cmd("progress", [0, get_latest_status(state, status_info)])
|
|
||||||
inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
|
|
||||||
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
|
|
||||||
if preprocess_type2 != None:
|
|
||||||
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
|
|
||||||
|
|
||||||
if video_guide_processed != None:
|
# Generic Video Preprocessing
|
||||||
if sample_fit_canvas != None:
|
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
|
||||||
image_size = video_guide_processed.shape[-3: -1]
|
preprocess_type, preprocess_type2 = "raw", None
|
||||||
sample_fit_canvas = None
|
for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
|
||||||
refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
|
if process_num == 0:
|
||||||
if video_guide_processed2 != None:
|
preprocess_type = process_map_video_guide.get(process_letter, "raw")
|
||||||
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
|
|
||||||
if video_mask_processed != None:
|
|
||||||
refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
|
|
||||||
|
|
||||||
frames_to_inject_parsed = frames_to_inject[window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]

if not vace and (any_letters(video_prompt_type, "FV") or model_def.get("forced_guide_mask_inputs", False)):
    any_mask = True
    any_guide_padding = model_def.get("pad_guide_video", False)
    from shared.utils.utils import prepare_video_guide_and_mask
    src_videos, src_masks = prepare_video_guide_and_mask([video_guide_processed, video_guide_processed2],
                                                         [video_mask_processed, video_mask_processed2],
                                                         pre_video_guide, image_size, current_video_length, latent_size,
                                                         any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
                                                         keep_frames_parsed, frames_to_inject_parsed, outpainting_dims)
    src_video, src_video2 = src_videos
    src_mask, src_mask2 = src_masks
    if src_video is None:
        abort = True
        break
    if src_faces is not None:
        if src_faces.shape[1] < src_video.shape[1]:
            src_faces = torch.concat([src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1, 1)], dim=1)
        else:
            src_faces = src_faces[:, :src_video.shape[1]]
    if args.save_masks:
        save_video(src_video, "masked_frames.mp4", fps)
        if src_video2 is not None:
            save_video(src_video2, "masked_frames2.mp4", fps)
        if any_mask:
            save_video(src_mask, "masks.mp4", fps, value_range=(0, 1))
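`prepare_video_guide_and_mask` is imported from `shared.utils.utils` and its body is not part of this diff. As a rough mental model only: a guide/mask pair is typically trimmed or padded to the requested frame count, with padded frames marked as "to generate". The sketch below assumes a `(f, h, w, c)` guide layout and a binary `(f, h, w)` mask; it is not the actual WanGP implementation.

```python
# Illustrative sketch only: pad a guide clip and its mask to `target_frames`.
import torch

def pad_guide_and_mask(guide: torch.Tensor, mask: torch.Tensor, target_frames: int, fill_value: float = 127.0):
    # guide: (f, h, w, c) frames in [0, 255], mask: (f, h, w) in {0, 1}
    f = guide.shape[0]
    if f >= target_frames:
        return guide[:target_frames], mask[:target_frames]
    pad = target_frames - f
    guide_pad = torch.full((pad, *guide.shape[1:]), fill_value, dtype=guide.dtype)
    mask_pad = torch.ones((pad, *mask.shape[1:]), dtype=mask.dtype)  # padded area is "to generate"
    return torch.cat([guide, guide_pad], dim=0), torch.cat([mask, mask_pad], dim=0)

guide = torch.zeros(17, 64, 64, 3)
mask = torch.zeros(17, 64, 64)
g, m = pad_guide_and_mask(guide, mask, 81)
print(g.shape, m.shape)   # torch.Size([81, 64, 64, 3]) torch.Size([81, 64, 64])
```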
elif ltxv:
    preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
    status_info = "Extracting " + processes_names[preprocess_type]
    send_cmd("progress", [0, get_latest_status(state, status_info)])
    # start one frame earlier to facilitate latents merging later
    src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width=image_size[1], max_frames=len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame=aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas=sample_fit_canvas, fit_crop=fit_crop, target_fps=fps, process_type=preprocess_type, inpaint_color=0, proc_no=1, negate_mask="N" in video_prompt_type, process_outside_mask="inpaint" if "X" in video_prompt_type else "identity", block_size=block_size)
    if src_video is not None:
        src_video = src_video[:(len(src_video) - 1) // latent_size * latent_size + 1]
        refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
        refresh_preview["video_mask"] = None
        src_video = src_video.permute(3, 0, 1, 2)
        src_video = src_video.float().div_(127.5).sub_(1.)  # c, f, h, w
        if sample_fit_canvas is not None:
            image_size = src_video.shape[-2:]
            sample_fit_canvas = None
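The trimming expression in the ltxv branch keeps only frame counts of the form `k * latent_size + 1`, which is what temporally compressing video VAEs decode cleanly. A quick worked check (the `latent_size` value here is an assumption for illustration):

```python
# Why the branch trims to (n - 1) // latent_size * latent_size + 1:
latent_size = 8  # illustrative value
for n in (17, 20, 81, 84):
    trimmed = (n - 1) // latent_size * latent_size + 1
    print(n, "->", trimmed)   # 17 -> 17, 20 -> 17, 81 -> 81, 84 -> 81
```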
elif hunyuan_custom_edit:
    if "P" in video_prompt_type:
        progress_args = [0, get_latest_status(state, "Extracting Open Pose Information and Expanding Mask")]
    else:
        progress_args = [0, get_latest_status(state, "Extracting Video and Mask")]
    send_cmd("progress", progress_args)
    src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width=width, max_frames=current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame=guide_start_frame, fit_canvas=sample_fit_canvas, fit_crop=fit_crop, target_fps=fps, process_type="pose" if "P" in video_prompt_type else "inpaint", negate_mask="N" in video_prompt_type, inpaint_color=0)
    refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
    if src_mask is not None:
        refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
elif "R" in video_prompt_type: # sparse video to video
|
|
||||||
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
|
|
||||||
src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size)
|
|
||||||
refresh_preview["video_guide"] = src_image
|
|
||||||
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
|
|
||||||
if sample_fit_canvas != None:
|
|
||||||
image_size = src_video.shape[-2:]
|
|
||||||
sample_fit_canvas = None
|
|
||||||
|
|
||||||
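`convert_image_to_tensor(...).unsqueeze(1)` above is assumed to produce a single-frame guide tensor. The sketch below shows one plausible way to get a `(c, 1, h, w)` tensor in `[-1, 1]` from a PIL frame; it is illustrative, not the WanGP helper.

```python
import numpy as np
import torch
from PIL import Image

def image_to_guide_tensor(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img.convert("RGB"), dtype=np.float32)        # (h, w, c) in [0, 255]
    t = torch.from_numpy(arr).permute(2, 0, 1)                    # (c, h, w)
    return t.div_(127.5).sub_(1.0).unsqueeze(1)                   # (c, 1, h, w) in [-1, 1]

frame = Image.new("RGB", (64, 48), (255, 0, 0))
print(image_to_guide_tensor(frame).shape)   # torch.Size([3, 1, 48, 64])
```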
else:  # video to video
    video_guide_processed = preprocess_video(width=image_size[1], height=image_size[0], video_in=video_guide, max_frames=len(keep_frames_parsed), start_frame=aligned_guide_start_frame, fit_canvas=sample_fit_canvas, fit_crop=fit_crop, target_fps=fps, block_size=block_size)
    if video_guide_processed is None:
        src_video = pre_video_guide
    else:
        if sample_fit_canvas is not None:
            image_size = video_guide_processed.shape[-3:-1]
            sample_fit_canvas = None
        src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1, 0, 1, 2)
        if pre_video_guide is not None:
            src_video = torch.cat([pre_video_guide, src_video], dim=1)
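The video-to-video branch normalizes the freshly extracted frames and concatenates them behind `pre_video_guide` along the time axis. A minimal sketch of that step, assuming a `(c, f, h, w)` layout for the pre-guide (layout and helper name are assumptions):

```python
import torch

def append_guide(pre_video_guide, new_frames_uint8):
    # new_frames_uint8: (f, h, w, c) in [0, 255] -> (c, f, h, w) in [-1, 1]
    new_part = new_frames_uint8.float().div_(127.5).sub_(1.0).permute(-1, 0, 1, 2)
    if pre_video_guide is None:
        return new_part
    return torch.cat([pre_video_guide, new_part], dim=1)

pre = torch.zeros(3, 8, 32, 32)
new = torch.randint(0, 256, (16, 32, 32, 3), dtype=torch.uint8)
print(append_guide(pre, new).shape)   # torch.Size([3, 24, 32, 32])
```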
elif image_guide is not None:
    new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas=sample_fit_canvas, fit_crop=fit_crop, block_size=block_size, expand_scale=mask_expand, outpainting_dims=outpainting_dims)
    if sample_fit_canvas is not None:
        image_size = (new_image_guide.size[1], new_image_guide.size[0])
        sample_fit_canvas = None
    refresh_preview["image_guide"] = new_image_guide
    if new_image_mask is not None:
        refresh_preview["image_mask"] = new_image_mask
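`preprocess_image_with_mask` is not shown in this diff; the essential constraint it is assumed to enforce is that the working canvas ends up with sides divisible by `block_size`. A toy illustration (the resampling filter, sizes, and block size are arbitrary choices, not the real routine):

```python
from PIL import Image

def fit_to_block(img: Image.Image, target_h: int, target_w: int, block_size: int = 16) -> Image.Image:
    # Snap the requested canvas down to the nearest multiple of block_size.
    h = max(block_size, target_h // block_size * block_size)
    w = max(block_size, target_w // block_size * block_size)
    return img.resize((w, h), Image.LANCZOS)

guide = Image.new("RGB", (1000, 562))
print(fit_to_block(guide, 720, 1280).size)   # (1280, 720)
```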
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
    if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type):

@@ -5193,44 +5056,67 @@ def generate_video(

        refresh_preview["image_refs"] = image_refs

    if len(image_refs) > nb_frames_positions:
        src_ref_images = image_refs[nb_frames_positions:]
        if remove_background_images_ref > 0:
            send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
        # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
        src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images, image_size[1], image_size[0],
                                                                     remove_background_images_ref > 0, any_background_ref,
                                                                     fit_into_canvas=model_def.get("fit_into_canvas_image_refs", 1),
                                                                     block_size=block_size,
                                                                     outpainting_dims=outpainting_dims,
                                                                     background_ref_outpainted=model_def.get("background_ref_outpainted", True),
                                                                     return_tensor=model_def.get("return_image_refs_tensor", False))
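The `nb_frames_positions` split used above can be summarized with plain list slicing: the first entries of `image_refs` are frames injected at fixed positions, the rest are subject/background references handed to the model. File names below are made up:

```python
image_refs = ["frame_pos_0.png", "frame_pos_1.png", "person.png", "background.png"]
nb_frames_positions = 2

frames_to_inject_refs = image_refs[:nb_frames_positions]
subject_refs = image_refs[nb_frames_positions:]
print(frames_to_inject_refs)   # ['frame_pos_0.png', 'frame_pos_1.png']
print(subject_refs)            # ['person.png', 'background.png']
```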
frames_to_inject_parsed = frames_to_inject[window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]

if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
    any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
    any_guide_padding = model_def.get("pad_guide_video", False)
    from shared.utils.utils import prepare_video_guide_and_mask
    src_videos, src_masks = prepare_video_guide_and_mask([video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
                                                         [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
                                                         None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
                                                         image_size, current_video_length, latent_size,
                                                         any_mask, any_guide_padding, guide_inpaint_color,
                                                         keep_frames_parsed, frames_to_inject_parsed, outpainting_dims)
    video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
    if len(src_videos) == 1:
        src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None
    else:
        src_video, src_video2 = src_videos
        src_mask, src_mask2 = src_masks
    src_videos = src_masks = None
    if src_video is None:
        abort = True
        break
    if src_faces is not None:
        if src_faces.shape[1] < src_video.shape[1]:
            src_faces = torch.concat([src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1, 1)], dim=1)
        else:
            src_faces = src_faces[:, :src_video.shape[1]]

if video_guide is not None or len(frames_to_inject_parsed) > 0:
    if args.save_masks:
        if src_video is not None: save_video(src_video, "masked_frames.mp4", fps)
        if src_video2 is not None: save_video(src_video2, "masked_frames2.mp4", fps)
        if any_mask: save_video(src_mask, "masks.mp4", fps, value_range=(0, 1))
    if video_guide is not None:
        preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
        refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
        if src_video2 is not None:
            refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
        if src_mask is not None and video_mask is not None:
            refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels=True)

if src_ref_images is not None or nb_frames_positions:
    if len(frames_to_inject_parsed):
        new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame))) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
    else:
        new_image_refs = []
    if src_ref_images is not None:
        new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images]
    refresh_preview["image_refs"] = new_image_refs
    new_image_refs = None
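`convert_tensor_to_image` is used above to refresh the UI previews. The sketch below shows the kind of conversion it is assumed to perform for a `(c, f, h, w)` tensor in `[-1, 1]`; it is a stand-in for illustration, not the WanGP function.

```python
import torch
from PIL import Image

def tensor_frame_to_image(video: torch.Tensor, frame_no: int = 0) -> Image.Image:
    frame = video[:, frame_no]                                    # (c, h, w)
    frame = frame.clamp(-1, 1).add(1).mul(127.5).byte()           # back to [0, 255]
    return Image.fromarray(frame.permute(1, 2, 0).cpu().numpy())  # (h, w, c)

video = torch.zeros(3, 16, 64, 64)
print(tensor_frame_to_image(video, 0).size)   # (64, 64)
```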
if len(refresh_preview) > 0:
    new_inputs = locals()

@@ -5339,8 +5225,6 @@ def generate_video(

    pre_video_frame = pre_video_frame,
    original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
    image_refs_relative_size = image_refs_relative_size,
    image_guide = new_image_guide,
    image_mask = new_image_mask,
    outpainting_dims = outpainting_dims,
)
except Exception as e:
@@ -6320,7 +6204,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None

    pop += ["image_refs_relative_size"]

    if not vace:
        pop += ["frames_positions", "control_net_weight", "control_net_weight2"]

    if model_def.get("video_guide_outpainting", None) is None:
        pop += ["video_guide_outpainting"]

    if not (vace or t2v):
        pop += ["min_frames_if_references"]
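The `pop` pattern in `prepare_inputs_dict` simply strips settings that do not apply to the current model before they are persisted. A reduced, self-contained illustration (keys and conditions are examples, and the real check reads `model_def`, not the inputs dict):

```python
inputs = {"video_length": 81, "frames_positions": "1 20", "control_net_weight": 1.0, "video_guide_outpainting": "0 0 0 0"}
pop = []
vace = False
if not vace:
    pop += ["frames_positions", "control_net_weight"]
if inputs.get("video_guide_outpainting") in (None, ""):
    pop += ["video_guide_outpainting"]
for key in pop:
    inputs.pop(key, None)
print(inputs)   # {'video_length': 81, 'video_guide_outpainting': '0 0 0 0'}
```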
@@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice):

    choice = min(choice, len(file_list))
    return gr.Gallery(value=file_list, selected_index=choice), gr.update() if len(file_list) > 0 else get_default_video_info(), gr.Row(visible=len(file_list) > 0)

def has_video_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".mp4"]

def has_image_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
    gen = get_gen_info(state)
    if files_to_load is None:
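A small usage example of the two extension helpers shown above (their bodies are restated verbatim so the snippet runs on its own):

```python
import os

def has_video_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [".mp4"]

def has_image_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

files_to_load = ["clip.mp4", "poster.webp", "notes.txt"]
print([f for f in files_to_load if has_video_file_extension(f)])   # ['clip.mp4']
print([f for f in files_to_load if has_image_file_extension(f)])   # ['poster.webp']
```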
@@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non

    elif recammaster:
        video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive=False, visible=True)
    else:
        min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)

    current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type) == "wan" else 97)
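Note that `get_model_min_frames_and_step` now returns three values and the third is ignored here. As a minimal illustration of how a model-dependent minimum and step could feed a Gradio frame-count slider (the numbers are made up and this is not the full WanGP layout):

```python
import gradio as gr

min_frames, frames_step = 5, 4      # illustrative values
default_length = 81

with gr.Blocks() as demo:
    video_length = gr.Slider(min_frames, 193, value=default_length, step=frames_step,
                             label="Number of frames (16 = 1s)")

# demo.launch()  # uncomment to try it locally
```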
@@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non

    MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)")

    with gr.Column(visible=any_control_video) as audio_prompt_type_remux_row:
        gr.Markdown("<B>You may transfer the existing audio tracks of a Control Video</B>")
        audio_prompt_type_remux = gr.Dropdown(
            choices=[
@@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non

    video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
    with gr.Row(**default_visibility) as video_buttons_row:
        video_info_extract_settings_btn = gr.Button("Extract Settings", min_width=1, size="sm")
        video_info_to_video_source_btn = gr.Button("To Video Source", min_width=1, size="sm", visible=any_video_source)
        video_info_to_control_video_btn = gr.Button("To Control Video", min_width=1, size="sm", visible=any_control_video)
        video_info_eject_video_btn = gr.Button("Eject Video", min_width=1, size="sm")
    with gr.Row(**default_visibility) as image_buttons_row:
        video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width=1, size="sm")
        video_info_to_start_image_btn = gr.Button("To Start Image", size="sm", min_width=1, visible=any_start_image)
        video_info_to_end_image_btn = gr.Button("To End Image", size="sm", min_width=1, visible=any_end_image)
        video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width=1, size="sm", visible=any_image_mask and False)
        video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width=1, size="sm", visible=any_reference_image)
        video_info_to_image_guide_btn = gr.Button("To Control Image", min_width=1, size="sm", visible=any_control_image)
        video_info_eject_image_btn = gr.Button("Eject Image", min_width=1, size="sm")
    with gr.Tab("Post Processing", id="post_processing", visible=True) as video_postprocessing_tab:
        with gr.Group(elem_classes="postprocess"):
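The reordered buttons only change the row layout; each one is still wired to a handler elsewhere in wgp.py. A simplified, self-contained sketch of how a "To Control Video" button could be wired in Gradio (the handler and component names are hypothetical, and the real app routes through its own state object):

```python
import gradio as gr

def to_control_video(selected_path):
    # In the real app the selected gallery item becomes the new control video.
    return selected_path

with gr.Blocks() as demo:
    selected_video = gr.Video(label="Selected output")
    control_video = gr.Video(label="Control Video")
    to_control_btn = gr.Button("To Control Video", size="sm", min_width=1)
    to_control_btn.click(to_control_video, inputs=selected_video, outputs=control_video)

# demo.launch()
```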