mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 06:15:17 +00:00

Vace Contenders are in Town
This commit is contained in:
parent 84010bd861
commit e28c95ae91
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates :

### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !

So in today's release you will find two Vace wannabes that each cover only a subset of Vace features but offer some interesting advantages:

- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*, and it does that very well. You can use this model either to *Replace* a person in an existing Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.

- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video, ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and therefore is very fast, especially if you use the *FastWan* finetune that is also part of the package.


### September 15 2025: WanGP v8.6 - Attack of the Clones

- The long awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**).
@@ -2,7 +2,7 @@
    "model": {
        "name": "Wan2.2 Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
        ],
        "group": "wan2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 5,
    "flow_shift": 5,
defaults/lucy_edit_fastwan.json (16 lines, Normal file)
@@ -0,0 +1,16 @@
{
    "model": {
        "name": "Wan2.2 FastWan Lucy Edit 5B",
        "architecture": "lucy_edit",
        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
        "URLs": "lucy_edit",
        "group": "wan2_2",
        "loras": "ti2v_2_2"
    },
    "prompt": "change the clothes to red",
    "video_length": 81,
    "guidance_scale": 1,
    "flow_shift": 3,
    "num_inference_steps": 5,
    "resolution": "1280x720"
}
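The FastWan preset mainly swaps the sampling settings of the base Lucy Edit preset. As a quick orientation, here is a minimal sketch (assuming a WanGP checkout where both files live under `defaults/`, as this diff suggests) that prints the parameters that differ:

```python
import json

# Minimal sketch, assuming the two preset files shown in this diff exist under defaults/.
def load_preset(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

base = load_preset("defaults/lucy_edit.json")
fast = load_preset("defaults/lucy_edit_fastwan.json")

for key in ("guidance_scale", "flow_shift", "num_inference_steps"):
    print(key, base.get(key), "->", fast.get(key))
# From the diff: guidance_scale 5 -> 1, flow_shift 5 -> 3, num_inference_steps (unset) -> 5,
# i.e. the FastWan finetune runs essentially CFG-free with only a few denoising steps.
```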
@@ -56,7 +56,7 @@ class family_handler():
            }


        extra_model_def["lock_image_refs_ratios"] = True
        extra_model_def["fit_into_canvas_image_refs"] = 0

        return extra_model_def

@@ -142,8 +142,8 @@ class model_factory:
            n_prompt: str = None,
            sampling_steps: int = 20,
            input_ref_images = None,
            image_guide= None,
            image_mask= None,
            input_frames= None,
            input_masks= None,
            width= 832,
            height=480,
            embedded_guidance_scale: float = 2.5,
@@ -197,10 +197,12 @@ class model_factory:
                    for new_img in input_ref_images[1:]:
                        stiched = stitch_images(stiched, new_img)
                    input_ref_images  = [stiched]
            elif image_guide is not None:
                input_ref_images = [image_guide]
            elif input_frames is not None:
                input_ref_images = [convert_tensor_to_image(input_frames) ]
            else:
                input_ref_images = None
            image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)


            if self.name in ['flux-dev-uso', 'flux-dev-umo']  :
                inp, height, width = prepare_multi_ip(
@@ -253,8 +255,8 @@ class model_factory:
            if image_mask is not None:
                from shared.utils.utils import convert_image_to_tensor
                img_msk_rebuilt = inp["img_msk_rebuilt"]
                img= convert_image_to_tensor(image_guide)
                x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
                img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
                x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt

            x = x.clamp(-1, 1)
            x = x.transpose(0, 1)
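The last hunk composites the generated result back into the guide frames using the rebuilt mask. A hedged sketch of that blend, with illustrative shapes rather than the exact WanGP tensors:

```python
import torch

# Sketch of the masked composite above: keep the original pixels where the mask is 0,
# take the generated result x where the mask is 1. Shapes are illustrative only.
img = torch.rand(3, 480, 832) * 2 - 1                       # guide frame in [-1, 1]
x = torch.rand(3, 480, 832) * 2 - 1                         # generated result
img_msk_rebuilt = (torch.rand(1, 480, 832) > 0.5).float()   # 1 = regenerate, 0 = preserve

x = img * (1 - img_msk_rebuilt) + x * img_msk_rebuilt
x = x.clamp(-1, 1)
```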
@@ -865,7 +865,7 @@ class HunyuanVideoSampler(Inference):
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
            else:
                if input_frames != None:
                    target_height, target_width = input_frames.shape[-3:-1]
                    target_height, target_width = input_frames.shape[-2:]
                elif input_video != None:
                    target_height, target_width = input_video.shape[-2:]

@@ -894,9 +894,10 @@ class HunyuanVideoSampler(Inference):
            pixel_value_bg = input_video.unsqueeze(0)
            pixel_value_mask =  torch.zeros_like(input_video).unsqueeze(0)
        if input_frames != None:
            pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
            pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
            pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
            pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
            # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
            # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
            pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
            if input_video != None:
                pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
                pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
@@ -908,10 +909,11 @@ class HunyuanVideoSampler(Inference):
            if pixel_value_bg.shape[2] < frame_num:
                padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] +  list(pixel_value_bg.shape[3:])
                pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
                pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
                # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
                pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)

            bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
            pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
            pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.)    # unmasked pixels are -1 (not 0 as usual) and masked are 1
            mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
            bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
            bg_latents.mul_(self.vae.config.scaling_factor)
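This hunk switches the mask handling from 0..255 pixel values to a 0/1 mask that is remapped to [-1, 1] before VAE encoding, then concatenated with the background latents along the channel axis. A rough sketch of that flow, with a stub in place of the VAE (shapes and the stub are assumptions, not the real pipeline):

```python
import torch
import torch.nn.functional as F

def stub_vae_encode(x: torch.Tensor) -> torch.Tensor:
    # stand-in for self.vae.encode(...).latent_dist.sample(); just downsamples T, H, W
    return F.avg_pool3d(x, kernel_size=(4, 8, 8))

pixel_value_bg = torch.rand(1, 3, 16, 64, 64) * 2 - 1               # frames already in [-1, 1]
pixel_value_mask = torch.randint(0, 2, (1, 3, 16, 64, 64)).float()  # 0 = keep, 1 = masked

bg_latents = stub_vae_encode(pixel_value_bg)
pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.0)              # map {0, 1} -> {-1, 1}
mask_latents = stub_vae_encode(pixel_value_mask)
cond_latents = torch.cat([bg_latents, mask_latents], dim=1)         # channel-wise conditioning
```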
@@ -35,6 +35,8 @@ class family_handler():
            "selection": ["", "A", "NA", "XA", "XNA"],
        }

        extra_model_def["extra_control_frames"] = 1
        extra_model_def["dont_cat_preguide"]= True
        return extra_model_def

    @staticmethod
@@ -17,7 +17,7 @@ class family_handler():
                            ("Default", "default"),
                            ("Lightning", "lightning")],
            "guidance_max_phases" : 1,
            "lock_image_refs_ratios": True,
            "fit_into_canvas_image_refs": 0,
        }

        if base_model_type in ["qwen_image_edit_20B"]:
@@ -17,7 +17,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image

def stitch_images(img1, img2):
    # Resize img2 to match img1's height
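Only the first comment line of `stitch_images` is visible in this hunk. As a hypothetical sketch of what such a helper typically does (resize the second image to the first one's height, then paste the two side by side), under the assumption that both arguments are PIL images:

```python
from PIL import Image

# Hypothetical sketch of an image-stitching helper like stitch_images above:
# resize img2 to img1's height (keeping aspect ratio), then paste them side by side.
def stitch_images_sketch(img1: Image.Image, img2: Image.Image) -> Image.Image:
    new_w = max(1, round(img2.width * img1.height / img2.height))
    img2 = img2.resize((new_w, img1.height), Image.LANCZOS)
    out = Image.new("RGB", (img1.width + img2.width, img1.height))
    out.paste(img1, (0, 0))
    out.paste(img2, (img1.width, 0))
    return out
```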
@@ -103,8 +103,8 @@ class model_factory():
        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        image_guide= None,
        image_mask= None,
        input_frames= None,
        input_masks= None,
        width= 832,
        height=480,
        guide_scale: float = 4,
@@ -179,8 +179,10 @@ class model_factory():

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt=  "text, watermark, copyright, blurry, low resolution"
        if image_guide is not None:
            input_ref_images = [image_guide]

        image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
        if input_frames is not None:
            input_ref_images = [convert_tensor_to_image(input_frames) ]
        elif input_ref_images is not None:
            # image stitching method
            stiched = input_ref_images[0]
models/wan/animate/animate_utils.py (143 lines, Normal file)
@@ -0,0 +1,143 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import torch
 | 
			
		||||
import numbers
 | 
			
		||||
from peft import LoraConfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
 | 
			
		||||
    target_modules = []
 | 
			
		||||
    for name, module in transformer.named_modules():
 | 
			
		||||
        if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
 | 
			
		||||
            target_modules.append(name)
 | 
			
		||||
 | 
			
		||||
    transformer_lora_config = LoraConfig(
 | 
			
		||||
        r=rank,
 | 
			
		||||
        lora_alpha=alpha,
 | 
			
		||||
        init_lora_weights=init_lora_weights,
 | 
			
		||||
        target_modules=target_modules,
 | 
			
		||||
    )
 | 
			
		||||
    return transformer_lora_config
 | 
			
		||||
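A hedged usage sketch for `get_loraconfig` above, using a tiny stand-in transformer so the snippet is self-contained; in WanGP the real Wan Animate diffusion transformer would be passed instead, and `peft.get_peft_model` is the standard way to attach the returned config:

```python
import torch
from peft import get_peft_model

class TinyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # module path contains "blocks", so get_loraconfig picks these Linears as LoRA targets
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(2)])

transformer = TinyTransformer()          # stand-in for the real Wan Animate transformer
lora_config = get_loraconfig(transformer, rank=8, alpha=8)
peft_model = get_peft_model(transformer, lora_config)
peft_model.print_trainable_parameters()  # only the injected LoRA matrices are trainable
```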
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TensorList(object):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tensors):
 | 
			
		||||
        """
 | 
			
		||||
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
 | 
			
		||||
        """
 | 
			
		||||
        assert isinstance(tensors, (list, tuple))
 | 
			
		||||
        assert all(isinstance(u, torch.Tensor) for u in tensors)
 | 
			
		||||
        assert len(set([u.ndim for u in tensors])) == 1
 | 
			
		||||
        assert len(set([u.dtype for u in tensors])) == 1
 | 
			
		||||
        assert len(set([u.device for u in tensors])) == 1
 | 
			
		||||
        self.tensors = tensors
 | 
			
		||||
    
 | 
			
		||||
    def to(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def size(self, dim):
 | 
			
		||||
        assert dim == 0, 'only support get the 0th size'
 | 
			
		||||
        return len(self.tensors)
 | 
			
		||||
    
 | 
			
		||||
    def pow(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def squeeze(self, dim):
 | 
			
		||||
        assert dim != 0
 | 
			
		||||
        if dim > 0:
 | 
			
		||||
            dim -= 1
 | 
			
		||||
        return TensorList([u.squeeze(dim) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def type(self, *args, **kwargs):
 | 
			
		||||
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def type_as(self, other):
 | 
			
		||||
        assert isinstance(other, (torch.Tensor, TensorList))
 | 
			
		||||
        if isinstance(other, torch.Tensor):
 | 
			
		||||
            return TensorList([u.type_as(other) for u in self.tensors])
 | 
			
		||||
        else:
 | 
			
		||||
            return TensorList([u.type(other.dtype) for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def dtype(self):
 | 
			
		||||
        return self.tensors[0].dtype
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def device(self):
 | 
			
		||||
        return self.tensors[0].device
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def ndim(self):
 | 
			
		||||
        return 1 + self.tensors[0].ndim
 | 
			
		||||
    
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        return self.tensors[index]
 | 
			
		||||
    
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.tensors)
 | 
			
		||||
    
 | 
			
		||||
    def __add__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u + v)
 | 
			
		||||
    
 | 
			
		||||
    def __radd__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v + u)
 | 
			
		||||
    
 | 
			
		||||
    def __sub__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u - v)
 | 
			
		||||
    
 | 
			
		||||
    def __rsub__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v - u)
 | 
			
		||||
    
 | 
			
		||||
    def __mul__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u * v)
 | 
			
		||||
    
 | 
			
		||||
    def __rmul__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v * u)
 | 
			
		||||
    
 | 
			
		||||
    def __floordiv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u // v)
 | 
			
		||||
    
 | 
			
		||||
    def __truediv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u / v)
 | 
			
		||||
    
 | 
			
		||||
    def __rfloordiv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v // u)
 | 
			
		||||
    
 | 
			
		||||
    def __rtruediv__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v / u)
 | 
			
		||||
    
 | 
			
		||||
    def __pow__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: u ** v)
 | 
			
		||||
    
 | 
			
		||||
    def __rpow__(self, other):
 | 
			
		||||
        return self._apply(other, lambda u, v: v ** u)
 | 
			
		||||
    
 | 
			
		||||
    def __neg__(self):
 | 
			
		||||
        return TensorList([-u for u in self.tensors])
 | 
			
		||||
    
 | 
			
		||||
    def __iter__(self):
 | 
			
		||||
        for tensor in self.tensors:
 | 
			
		||||
            yield tensor
 | 
			
		||||
    
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return 'TensorList: \n' + repr(self.tensors)
 | 
			
		||||
 | 
			
		||||
    def _apply(self, other, op):
 | 
			
		||||
        if isinstance(other, (list, tuple, TensorList)) or (
 | 
			
		||||
            isinstance(other, torch.Tensor) and (
 | 
			
		||||
                other.numel() > 1 or other.ndim > 1
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            assert len(other) == len(self.tensors)
 | 
			
		||||
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
 | 
			
		||||
        elif isinstance(other, numbers.Number) or (
 | 
			
		||||
            isinstance(other, torch.Tensor) and (
 | 
			
		||||
                other.numel() == 1 and other.ndim <= 1
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            return TensorList([op(u, other) for u in self.tensors])
 | 
			
		||||
        else:
 | 
			
		||||
            raise TypeError(
 | 
			
		||||
                f'unsupported operand for *: "TensorList" and "{type(other)}"'
 | 
			
		||||
            )
 | 
			
		||||
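A brief usage sketch for `TensorList` above: it lets a list of differently-shaped latents be treated like a single tensor in the arithmetic used by the samplers (broadcast against scalars, elementwise ops against another `TensorList` of the same length):

```python
import torch

latents = TensorList([torch.randn(4, 8, 8), torch.randn(4, 6, 10)])
noise = TensorList([torch.randn_like(u) for u in latents])

scaled = 0.5 * latents + 0.5 * noise              # per-tensor elementwise arithmetic
print(scaled.size(0), scaled.dtype, scaled.ndim)  # 2 tensors, shared dtype, ndim = 1 + tensor ndim
```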
							
								
								
									
models/wan/animate/face_blocks.py (382 lines, Normal file)
@@ -0,0 +1,382 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
from torch import nn
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Tuple, Optional
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import math
 | 
			
		||||
from shared.attention import pay_attention
 | 
			
		||||
 | 
			
		||||
MEMORY_LAYOUT = {
 | 
			
		||||
    "flash": (
 | 
			
		||||
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
 | 
			
		||||
        lambda x: x,
 | 
			
		||||
    ),
 | 
			
		||||
    "torch": (
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
    ),
 | 
			
		||||
    "vanilla": (
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
        lambda x: x.transpose(1, 2),
 | 
			
		||||
    ),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def attention(
 | 
			
		||||
    q,
 | 
			
		||||
    k,
 | 
			
		||||
    v,
 | 
			
		||||
    mode="torch",
 | 
			
		||||
    drop_rate=0,
 | 
			
		||||
    attn_mask=None,
 | 
			
		||||
    causal=False,
 | 
			
		||||
    max_seqlen_q=None,
 | 
			
		||||
    batch_size=1,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Perform QKV self attention.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
 | 
			
		||||
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
 | 
			
		||||
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
 | 
			
		||||
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
 | 
			
		||||
        drop_rate (float): Dropout rate in attention map. (default: 0)
 | 
			
		||||
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
 | 
			
		||||
            (default: None)
 | 
			
		||||
        causal (bool): Whether to use causal attention. (default: False)
 | 
			
		||||
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
 | 
			
		||||
            used to index into q.
 | 
			
		||||
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
 | 
			
		||||
            used to index into kv.
 | 
			
		||||
        max_seqlen_q (int): The maximum sequence length in the batch of q.
 | 
			
		||||
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
 | 
			
		||||
    """
 | 
			
		||||
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
 | 
			
		||||
 | 
			
		||||
    if mode == "torch":
 | 
			
		||||
        if attn_mask is not None and attn_mask.dtype != torch.bool:
 | 
			
		||||
            attn_mask = attn_mask.to(q.dtype)
 | 
			
		||||
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
 | 
			
		||||
 | 
			
		||||
    elif mode == "flash":
 | 
			
		||||
        x = flash_attn_func(
 | 
			
		||||
            q,
 | 
			
		||||
            k,
 | 
			
		||||
            v,
 | 
			
		||||
        )
 | 
			
		||||
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
 | 
			
		||||
    elif mode == "vanilla":
 | 
			
		||||
        scale_factor = 1 / math.sqrt(q.size(-1))
 | 
			
		||||
 | 
			
		||||
        b, a, s, _ = q.shape
 | 
			
		||||
        s1 = k.size(2)
 | 
			
		||||
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
 | 
			
		||||
        if causal:
 | 
			
		||||
            # Only applied to self attention
 | 
			
		||||
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
 | 
			
		||||
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
 | 
			
		||||
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
 | 
			
		||||
            attn_bias.to(q.dtype)
 | 
			
		||||
 | 
			
		||||
        if attn_mask is not None:
 | 
			
		||||
            if attn_mask.dtype == torch.bool:
 | 
			
		||||
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
 | 
			
		||||
            else:
 | 
			
		||||
                attn_bias += attn_mask
 | 
			
		||||
 | 
			
		||||
        attn = (q @ k.transpose(-2, -1)) * scale_factor
 | 
			
		||||
        attn += attn_bias
 | 
			
		||||
        attn = attn.softmax(dim=-1)
 | 
			
		||||
        attn = torch.dropout(attn, p=drop_rate, train=True)
 | 
			
		||||
        x = attn @ v
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
 | 
			
		||||
 | 
			
		||||
    x = post_attn_layout(x)
 | 
			
		||||
    b, s, a, d = x.shape
 | 
			
		||||
    out = x.reshape(b, s, -1)
 | 
			
		||||
    return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CausalConv1d(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.pad_mode = pad_mode
 | 
			
		||||
        padding = (kernel_size - 1, 0)  # T
 | 
			
		||||
        self.time_causal_padding = padding
 | 
			
		||||
 | 
			
		||||
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
 | 
			
		||||
        return self.conv(x)
 | 
			
		||||
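A quick check of the causal padding idea in `CausalConv1d` above: with `(kernel_size - 1)` left padding, output step t only depends on input steps <= t, so perturbing a future frame leaves earlier outputs untouched (shapes below are illustrative):

```python
import torch

conv = CausalConv1d(chan_in=2, chan_out=2, kernel_size=3)
x = torch.randn(1, 2, 10)
y1 = conv(x)

x2 = x.clone()
x2[..., -1] += 1.0            # perturb only the last time step
y2 = conv(x2)
print(torch.allclose(y1[..., :-1], y2[..., :-1]))  # True: earlier outputs unchanged
```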
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceEncoder(nn.Module):
 | 
			
		||||
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
 | 
			
		||||
        factory_kwargs = {"dtype": dtype, "device": device}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.num_heads = num_heads
 | 
			
		||||
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
 | 
			
		||||
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
        self.act = nn.SiLU()
 | 
			
		||||
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
 | 
			
		||||
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
 | 
			
		||||
 | 
			
		||||
        self.out_proj = nn.Linear(1024, hidden_dim)
 | 
			
		||||
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        b, c, t = x.shape
 | 
			
		||||
 | 
			
		||||
        x = self.conv1_local(x)
 | 
			
		||||
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
 | 
			
		||||
        
 | 
			
		||||
        x = self.norm1(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        x = self.conv2(x)
 | 
			
		||||
        x = rearrange(x, "b c t -> b t c")
 | 
			
		||||
        x = self.norm2(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = rearrange(x, "b t c -> b c t")
 | 
			
		||||
        x = self.conv3(x)
 | 
			
		||||
        x = rearrange(x, "b c t -> b t c")
 | 
			
		||||
        x = self.norm3(x)
 | 
			
		||||
        x = self.act(x)
 | 
			
		||||
        x = self.out_proj(x)
 | 
			
		||||
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
 | 
			
		||||
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
 | 
			
		||||
        x = torch.cat([x, padding], dim=-2)
 | 
			
		||||
        x_local = x.clone()
 | 
			
		||||
 | 
			
		||||
        return x_local
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RMSNorm(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        dim: int,
 | 
			
		||||
        elementwise_affine=True,
 | 
			
		||||
        eps: float = 1e-6,
 | 
			
		||||
        device=None,
 | 
			
		||||
        dtype=None,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Initialize the RMSNorm normalization layer.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            dim (int): The dimension of the input tensor.
 | 
			
		||||
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
 | 
			
		||||
 | 
			
		||||
        Attributes:
 | 
			
		||||
            eps (float): A small value added to the denominator for numerical stability.
 | 
			
		||||
            weight (nn.Parameter): Learnable scaling parameter.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        factory_kwargs = {"device": device, "dtype": dtype}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        if elementwise_affine:
 | 
			
		||||
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
 | 
			
		||||
 | 
			
		||||
    def _norm(self, x):
 | 
			
		||||
        """
 | 
			
		||||
        Apply the RMSNorm normalization to the input tensor.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x (torch.Tensor): The input tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: The normalized tensor.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        """
 | 
			
		||||
        Forward pass through the RMSNorm layer.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x (torch.Tensor): The input tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            torch.Tensor: The output tensor after applying RMSNorm.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        output = self._norm(x.float()).type_as(x)
 | 
			
		||||
        if hasattr(self, "weight"):
 | 
			
		||||
            output = output * self.weight
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_norm_layer(norm_layer):
 | 
			
		||||
    """
 | 
			
		||||
    Get the normalization layer.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        norm_layer (str): The type of normalization layer.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        norm_layer (nn.Module): The normalization layer.
 | 
			
		||||
    """
 | 
			
		||||
    if norm_layer == "layer":
 | 
			
		||||
        return nn.LayerNorm
 | 
			
		||||
    elif norm_layer == "rms":
 | 
			
		||||
        return RMSNorm
 | 
			
		||||
    else:
 | 
			
		||||
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceAdapter(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_dim: int,
 | 
			
		||||
        heads_num: int,
 | 
			
		||||
        qk_norm: bool = True,
 | 
			
		||||
        qk_norm_type: str = "rms",
 | 
			
		||||
        num_adapter_layers: int = 1,
 | 
			
		||||
        dtype=None,
 | 
			
		||||
        device=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        factory_kwargs = {"dtype": dtype, "device": device}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_dim
 | 
			
		||||
        self.heads_num = heads_num
 | 
			
		||||
        self.fuser_blocks = nn.ModuleList(
 | 
			
		||||
            [
 | 
			
		||||
                FaceBlock(
 | 
			
		||||
                    self.hidden_size,
 | 
			
		||||
                    self.heads_num,
 | 
			
		||||
                    qk_norm=qk_norm,
 | 
			
		||||
                    qk_norm_type=qk_norm_type,
 | 
			
		||||
                    **factory_kwargs,
 | 
			
		||||
                )
 | 
			
		||||
                for _ in range(num_adapter_layers)
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        motion_embed: torch.Tensor,
 | 
			
		||||
        idx: int,
 | 
			
		||||
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
 | 
			
		||||
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FaceBlock(nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        heads_num: int,
 | 
			
		||||
        qk_norm: bool = True,
 | 
			
		||||
        qk_norm_type: str = "rms",
 | 
			
		||||
        qk_scale: float = None,
 | 
			
		||||
        dtype: Optional[torch.dtype] = None,
 | 
			
		||||
        device: Optional[torch.device] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        factory_kwargs = {"device": device, "dtype": dtype}
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.deterministic = False
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        self.heads_num = heads_num
 | 
			
		||||
        head_dim = hidden_size // heads_num
 | 
			
		||||
        self.scale = qk_scale or head_dim**-0.5
 | 
			
		||||
       
 | 
			
		||||
        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
 | 
			
		||||
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        qk_norm_layer = get_norm_layer(qk_norm_type)
 | 
			
		||||
        self.q_norm = (
 | 
			
		||||
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
 | 
			
		||||
        )
 | 
			
		||||
        self.k_norm = (
 | 
			
		||||
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
        motion_vec: torch.Tensor,
 | 
			
		||||
        motion_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        use_context_parallel=False,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        
 | 
			
		||||
        B, T, N, C = motion_vec.shape
 | 
			
		||||
        T_comp = T
 | 
			
		||||
 | 
			
		||||
        x_motion = self.pre_norm_motion(motion_vec)
 | 
			
		||||
        x_feat = self.pre_norm_feat(x)
 | 
			
		||||
 | 
			
		||||
        kv = self.linear1_kv(x_motion)
 | 
			
		||||
        q = self.linear1_q(x_feat)
 | 
			
		||||
 | 
			
		||||
        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
 | 
			
		||||
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
 | 
			
		||||
 | 
			
		||||
        # Apply QK-Norm if needed.
 | 
			
		||||
        q = self.q_norm(q).to(v)
 | 
			
		||||
        k = self.k_norm(k).to(v)
 | 
			
		||||
 | 
			
		||||
        k = rearrange(k, "B L N H D -> (B L) N H D")  
 | 
			
		||||
        v = rearrange(v, "B L N H D -> (B L) N H D") 
 | 
			
		||||
 | 
			
		||||
        if use_context_parallel:
 | 
			
		||||
            q = gather_forward(q, dim=1)
 | 
			
		||||
 | 
			
		||||
        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)  
 | 
			
		||||
        # Compute attention.
 | 
			
		||||
    # Size([batches, tokens, heads, head_features])
 | 
			
		||||
        qkv_list = [q, k, v]
 | 
			
		||||
        del q,k,v
 | 
			
		||||
        attn = pay_attention(qkv_list)
 | 
			
		||||
        # attn = attention(
 | 
			
		||||
        #     q,
 | 
			
		||||
        #     k,
 | 
			
		||||
        #     v,
 | 
			
		||||
        #     max_seqlen_q=q.shape[1],
 | 
			
		||||
        #     batch_size=q.shape[0],
 | 
			
		||||
        # )
 | 
			
		||||
 | 
			
		||||
        attn = attn.reshape(*attn.shape[:2], -1)
 | 
			
		||||
        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
 | 
			
		||||
        # if use_context_parallel:
 | 
			
		||||
        #     attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
 | 
			
		||||
 | 
			
		||||
        output = self.linear2(attn)
 | 
			
		||||
 | 
			
		||||
        if motion_mask is not None:
 | 
			
		||||
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
 | 
			
		||||
 | 
			
		||||
        return output
 | 
			
		||||
							
								
								
									
models/wan/animate/model_animate.py (31 lines, Normal file)
@@ -0,0 +1,31 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import types
from copy import deepcopy
from einops import rearrange
from typing import List
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn

def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
    pose_latents = self.pose_patch_embedding(pose_latents)
    x[:, :, 1:] += pose_latents

    b,c,T,h,w = face_pixel_values.shape
    face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
    encode_bs = 8
    face_pixel_values_tmp = []
    for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
        face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))

    motion_vec = torch.cat(face_pixel_values_tmp)

    motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
    motion_vec = self.face_encoder(motion_vec)

    B, L, H, C = motion_vec.shape
    pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
    motion_vec = torch.cat([pad_face, motion_vec], dim=1)
    return x, motion_vec
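The chunked loop in `after_patch_embedding` keeps VRAM bounded when encoding long face-crop sequences. A generic sketch of the same pattern, where `encoder` stands in for `self.motion_encoder.get_motion` and `frames` for the flattened `(b*t, c, h, w)` face crops:

```python
import torch

def encode_in_chunks(encoder, frames: torch.Tensor, encode_bs: int = 8) -> torch.Tensor:
    # encode the frames in small batches and concatenate the results
    chunks = []
    for i in range(0, frames.shape[0], encode_bs):
        chunks.append(encoder(frames[i:i + encode_bs]))
    return torch.cat(chunks)

# Example with a trivial encoder:
frames = torch.randn(20, 3, 64, 64)
feats = encode_in_chunks(lambda f: f.mean(dim=(2, 3)), frames)
print(feats.shape)  # torch.Size([20, 3])
```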
							
								
								
									
models/wan/animate/motion_encoder.py (308 lines, Normal file)
@@ -0,0 +1,308 @@
# Modified from ``https://github.com/wyhsirius/LIA``
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
def custom_qr(input_tensor):
 | 
			
		||||
    original_dtype = input_tensor.dtype
 | 
			
		||||
    if original_dtype in [torch.bfloat16, torch.float16]:
 | 
			
		||||
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
 | 
			
		||||
        return q.to(original_dtype), r.to(original_dtype)
 | 
			
		||||
    return torch.linalg.qr(input_tensor)
 | 
			
		||||
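Why `custom_qr` upcasts: `torch.linalg.qr` does not provide half/bfloat16 kernels, so the input is promoted to float32 for the decomposition and the factors are cast back afterwards. A small check (illustrative shapes):

```python
import torch

w = torch.randn(512, 20, dtype=torch.bfloat16)
q, r = custom_qr(w)
print(q.dtype, r.dtype)                                   # torch.bfloat16, torch.bfloat16
print((q.float() @ r.float() - w.float()).abs().max())    # small reconstruction error
```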
 | 
			
		||||
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
 | 
			
		||||
	return F.leaky_relu(input + bias, negative_slope) * scale
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
 | 
			
		||||
	_, minor, in_h, in_w = input.shape
 | 
			
		||||
	kernel_h, kernel_w = kernel.shape
 | 
			
		||||
 | 
			
		||||
	out = input.view(-1, minor, in_h, 1, in_w, 1)
 | 
			
		||||
	out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
 | 
			
		||||
	out = out.view(-1, minor, in_h * up_y, in_w * up_x)
 | 
			
		||||
 | 
			
		||||
	out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
 | 
			
		||||
	out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
 | 
			
		||||
		  max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
 | 
			
		||||
 | 
			
		||||
	out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
 | 
			
		||||
	w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
 | 
			
		||||
	out = F.conv2d(out, w)
 | 
			
		||||
	out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
 | 
			
		||||
					  in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
 | 
			
		||||
	return out[:, :, ::down_y, ::down_x]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 | 
			
		||||
	return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_kernel(k):
 | 
			
		||||
	k = torch.tensor(k, dtype=torch.float32)
 | 
			
		||||
	if k.ndim == 1:
 | 
			
		||||
		k = k[None, :] * k[:, None]
 | 
			
		||||
	k /= k.sum()
 | 
			
		||||
	return k
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FusedLeakyReLU(nn.Module):
 | 
			
		||||
	def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
		self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
 | 
			
		||||
		self.negative_slope = negative_slope
 | 
			
		||||
		self.scale = scale
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Blur(nn.Module):
 | 
			
		||||
	def __init__(self, kernel, pad, upsample_factor=1):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		kernel = make_kernel(kernel)
 | 
			
		||||
 | 
			
		||||
		if upsample_factor > 1:
 | 
			
		||||
			kernel = kernel * (upsample_factor ** 2)
 | 
			
		||||
 | 
			
		||||
		self.register_buffer('kernel', kernel)
 | 
			
		||||
 | 
			
		||||
		self.pad = pad
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		return upfirdn2d(input, self.kernel, pad=self.pad)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ScaledLeakyReLU(nn.Module):
 | 
			
		||||
	def __init__(self, negative_slope=0.2):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.negative_slope = negative_slope
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		return F.leaky_relu(input, negative_slope=self.negative_slope)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EqualConv2d(nn.Module):
 | 
			
		||||
	def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
 | 
			
		||||
		self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
 | 
			
		||||
 | 
			
		||||
		self.stride = stride
 | 
			
		||||
		self.padding = padding
 | 
			
		||||
 | 
			
		||||
		if bias:
 | 
			
		||||
			self.bias = nn.Parameter(torch.zeros(out_channel))
 | 
			
		||||
		else:
 | 
			
		||||
			self.bias = None
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
 | 
			
		||||
		return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
 | 
			
		||||
 | 
			
		||||
	def __repr__(self):
 | 
			
		||||
		return (
 | 
			
		||||
			f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
 | 
			
		||||
			f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EqualLinear(nn.Module):
 | 
			
		||||
	def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
 | 
			
		||||
 | 
			
		||||
		if bias:
 | 
			
		||||
			self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
 | 
			
		||||
		else:
 | 
			
		||||
			self.bias = None
 | 
			
		||||
 | 
			
		||||
		self.activation = activation
 | 
			
		||||
 | 
			
		||||
		self.scale = (1 / math.sqrt(in_dim)) * lr_mul
 | 
			
		||||
		self.lr_mul = lr_mul
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
 | 
			
		||||
		if self.activation:
 | 
			
		||||
			out = F.linear(input, self.weight * self.scale)
 | 
			
		||||
			out = fused_leaky_relu(out, self.bias * self.lr_mul)
 | 
			
		||||
		else:
 | 
			
		||||
			out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
 | 
			
		||||
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
	def __repr__(self):
 | 
			
		||||
		return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConvLayer(nn.Sequential):
 | 
			
		||||
	def __init__(
 | 
			
		||||
			self,
 | 
			
		||||
			in_channel,
 | 
			
		||||
			out_channel,
 | 
			
		||||
			kernel_size,
 | 
			
		||||
			downsample=False,
 | 
			
		||||
			blur_kernel=[1, 3, 3, 1],
 | 
			
		||||
			bias=True,
 | 
			
		||||
			activate=True,
 | 
			
		||||
	):
 | 
			
		||||
		layers = []
 | 
			
		||||
 | 
			
		||||
		if downsample:
 | 
			
		||||
			factor = 2
 | 
			
		||||
			p = (len(blur_kernel) - factor) + (kernel_size - 1)
 | 
			
		||||
			pad0 = (p + 1) // 2
 | 
			
		||||
			pad1 = p // 2
 | 
			
		||||
 | 
			
		||||
			layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
 | 
			
		||||
 | 
			
		||||
			stride = 2
 | 
			
		||||
			self.padding = 0
 | 
			
		||||
 | 
			
		||||
		else:
 | 
			
		||||
			stride = 1
 | 
			
		||||
			self.padding = kernel_size // 2
 | 
			
		||||
 | 
			
		||||
		layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
 | 
			
		||||
								  bias=bias and not activate))
 | 
			
		||||
 | 
			
		||||
		if activate:
 | 
			
		||||
			if bias:
 | 
			
		||||
				layers.append(FusedLeakyReLU(out_channel))
 | 
			
		||||
			else:
 | 
			
		||||
				layers.append(ScaledLeakyReLU(0.2))
 | 
			
		||||
 | 
			
		||||
		super().__init__(*layers)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ResBlock(nn.Module):
 | 
			
		||||
	def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.conv1 = ConvLayer(in_channel, in_channel, 3)
 | 
			
		||||
		self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
 | 
			
		||||
 | 
			
		||||
		self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
 | 
			
		||||
 | 
			
		||||
	def forward(self, input):
 | 
			
		||||
		out = self.conv1(input)
 | 
			
		||||
		out = self.conv2(out)
 | 
			
		||||
 | 
			
		||||
		skip = self.skip(input)
 | 
			
		||||
		out = (out + skip) / math.sqrt(2)
 | 
			
		||||
 | 
			
		||||
		return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EncoderApp(nn.Module):
 | 
			
		||||
	def __init__(self, size, w_dim=512):
 | 
			
		||||
		super(EncoderApp, self).__init__()
 | 
			
		||||
 | 
			
		||||
		channels = {
 | 
			
		||||
			4: 512,
 | 
			
		||||
			8: 512,
 | 
			
		||||
			16: 512,
 | 
			
		||||
			32: 512,
 | 
			
		||||
			64: 256,
 | 
			
		||||
			128: 128,
 | 
			
		||||
			256: 64,
 | 
			
		||||
			512: 32,
 | 
			
		||||
			1024: 16
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		self.w_dim = w_dim
 | 
			
		||||
		log_size = int(math.log(size, 2))
 | 
			
		||||
 | 
			
		||||
		self.convs = nn.ModuleList()
 | 
			
		||||
		self.convs.append(ConvLayer(3, channels[size], 1))
 | 
			
		||||
 | 
			
		||||
		in_channel = channels[size]
 | 
			
		||||
		for i in range(log_size, 2, -1):
 | 
			
		||||
			out_channel = channels[2 ** (i - 1)]
 | 
			
		||||
			self.convs.append(ResBlock(in_channel, out_channel))
 | 
			
		||||
			in_channel = out_channel
 | 
			
		||||
 | 
			
		||||
		self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
 | 
			
		||||
 | 
			
		||||
	def forward(self, x):
 | 
			
		||||
 | 
			
		||||
		res = []
 | 
			
		||||
		h = x
 | 
			
		||||
		for conv in self.convs:
 | 
			
		||||
			h = conv(h)
 | 
			
		||||
			res.append(h)
 | 
			
		||||
 | 
			
		||||
		return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Encoder(nn.Module):
 | 
			
		||||
	def __init__(self, size, dim=512, dim_motion=20):
 | 
			
		||||
		super(Encoder, self).__init__()
 | 
			
		||||
 | 
			
		||||
		# appearance netmork
 | 
			
		||||
		self.net_app = EncoderApp(size, dim)
 | 
			
		||||
 | 
			
		||||
		# motion network
 | 
			
		||||
		fc = [EqualLinear(dim, dim)]
 | 
			
		||||
		for i in range(3):
 | 
			
		||||
			fc.append(EqualLinear(dim, dim))
 | 
			
		||||
 | 
			
		||||
		fc.append(EqualLinear(dim, dim_motion))
 | 
			
		||||
		self.fc = nn.Sequential(*fc)
 | 
			
		||||
 | 
			
		||||
	def enc_app(self, x):
 | 
			
		||||
		h_source = self.net_app(x)
 | 
			
		||||
		return h_source
 | 
			
		||||
 | 
			
		||||
	def enc_motion(self, x):
 | 
			
		||||
		h, _ = self.net_app(x)
 | 
			
		||||
		h_motion = self.fc(h)
 | 
			
		||||
		return h_motion
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Direction(nn.Module):
 | 
			
		||||
    def __init__(self, motion_dim):
 | 
			
		||||
        super(Direction, self).__init__()
 | 
			
		||||
        self.weight = nn.Parameter(torch.randn(512, motion_dim))
 | 
			
		||||
 | 
			
		||||
    def forward(self, input):
 | 
			
		||||
 | 
			
		||||
        weight = self.weight + 1e-8
 | 
			
		||||
        Q, R = custom_qr(weight)
 | 
			
		||||
        if input is None:
 | 
			
		||||
            return Q
 | 
			
		||||
        else:
 | 
			
		||||
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
 | 
			
		||||
            out = torch.matmul(input_diag, Q.T)
 | 
			
		||||
            out = torch.sum(out, dim=1)
 | 
			
		||||
            return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Synthesis(nn.Module):
 | 
			
		||||
    def __init__(self, motion_dim):
 | 
			
		||||
        super(Synthesis, self).__init__()
 | 
			
		||||
        self.direction = Direction(motion_dim)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Generator(nn.Module):
 | 
			
		||||
	def __init__(self, size, style_dim=512, motion_dim=20):
 | 
			
		||||
		super().__init__()
 | 
			
		||||
 | 
			
		||||
		self.enc = Encoder(size, style_dim, motion_dim)
 | 
			
		||||
		self.dec = Synthesis(motion_dim)
 | 
			
		||||
 | 
			
		||||
	def get_motion(self, img):
 | 
			
		||||
		#motion_feat = self.enc.enc_motion(img)
 | 
			
		||||
		# motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
 | 
			
		||||
		with torch.cuda.amp.autocast(dtype=torch.float32):
 | 
			
		||||
			motion_feat = self.enc.enc_motion(img)
 | 
			
		||||
			motion = self.dec.direction(motion_feat)
 | 
			
		||||
		return motion
 | 
			
		||||
@@ -203,10 +203,7 @@ class WanAny2V:
        self.use_timestep_transform = True

    def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)
        ref_images = [ref_images] * len(frames)

        if masks is None:
            latents = self.vae.encode(frames, tile_size = tile_size)
@@ -238,11 +235,7 @@ class WanAny2V:
        return cat_latents

    def vace_encode_masks(self, masks, ref_images=None):
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        ref_images = [ref_images] * len(masks)
        result_masks = []
        for mask, refs in zip(masks, ref_images):
            c, depth, height, width = mask.shape
@@ -270,79 +263,6 @@ class WanAny2V:
            result_masks.append(mask)
        return result_masks

    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]


    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
        image_sizes = []
        trim_video_guide = len(keep_video_guide_frames)
        def conv_tensor(t, device):
            return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)

        for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
            prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
            num_frames = total_frames - prepend_count
            num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
            if sub_src_mask is not None and sub_src_video is not None:
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
                # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
                # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
                if prepend_count > 0:
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                    src_mask[i] =  torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
                src_video_shape = src_video[i].shape
                if src_video_shape[1] != total_frames:
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
                image_sizes.append(src_video[i].shape[2:])
            elif sub_src_video is None:
                if prepend_count > 0:
                    src_video[i] =  torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
                else:
                    src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
                    src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(image_size)
            else:
                src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                if prepend_count > 0:
                    src_video[i] =  torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
                    src_mask[i] =  torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
                src_video_shape = src_video[i].shape
                if src_video_shape[1] != total_frames:
                    src_video[i] =  torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                    src_mask[i] =  torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
                image_sizes.append(src_video[i].shape[2:])
            for k, keep in enumerate(keep_video_guide_frames):
                if not keep:
                    pos = prepend_count + k
                    src_video[i][:, pos:pos+1] = 0
                    src_mask[i][:, pos:pos+1] = 1

            for k, frame in enumerate(inject_frames):
                if frame != None:
                    pos = prepend_count + k
                    src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)


        self.background_mask = None
        for i, ref_images in enumerate(src_ref_images):
            if ref_images is not None:
                image_size = image_sizes[i]
                for j, ref_img in enumerate(ref_images):
                    if ref_img is not None and not torch.is_tensor(ref_img):
                        if j==0 and any_background_ref:
                            if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
                            src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
                        else:
                            src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
        if self.background_mask != None:
            self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate background mask with double control net since first controlnet image ref modified by ref
        return src_video, src_mask, src_ref_images

    def get_vae_latents(self, ref_images, device, tile_size= 0):
        ref_vae_latents = []
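conv_tensor above defines the value conventions used throughout prepare_source: guide frames arrive as uint8 (f, h, w, c) in [0, 255] and are mapped to float (c, f, h, w) in [-1, 1], so pixel value 127 lands near 0, the inpainting color. A small self-contained sketch of that mapping:

import torch

def to_guide_tensor(frames_u8: torch.Tensor) -> torch.Tensor:
    # frames_u8: (f, h, w, c) uint8 in [0, 255]
    # returns:   (c, f, h, w) float in [-1, 1]; 127 maps to roughly 0 (the inpaint color)
    return frames_u8.float().div_(127.5).add_(-1).permute(3, 0, 1, 2)

frames = torch.randint(0, 256, (8, 32, 32, 3), dtype=torch.uint8)
guide = to_guide_tensor(frames)
assert guide.shape == (3, 8, 32, 32)
assert guide.min() >= -1 and guide.max() <= 1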
@@ -369,7 +289,9 @@ class WanAny2V:
    def generate(self,
        input_prompt,
        input_frames= None,
        input_frames2= None,
        input_masks = None,
        input_masks2 = None,
        input_ref_images = None,
        input_ref_masks = None,
        input_faces = None,
@@ -615,21 +537,22 @@ class WanAny2V:
            pose_pixels = input_frames * input_masks
            input_masks = 1. - input_masks
            pose_pixels -= input_masks
            save_video(pose_pixels, "pose.mp4")
            pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
            input_frames = input_frames * input_masks
            if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should be black (-1) in background frames
            if prefix_frames_count > 0:
                 input_frames[:, :prefix_frames_count] = input_video
                 input_masks[:, :prefix_frames_count] = 1
            save_video(input_frames, "input_frames.mp4")
            save_video(input_masks, "input_masks.mp4", value_range=(0,1))
            # save_video(pose_pixels, "pose.mp4")
            # save_video(input_frames, "input_frames.mp4")
            # save_video(input_masks, "input_masks.mp4", value_range=(0,1))
            lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
            msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
            msk_control =  self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
            msk = torch.concat([msk_ref, msk_control], dim=1)
            clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
            lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
            image_ref = input_ref_images[0].to(self.device)
            clip_image_start = image_ref.squeeze(1)
            lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
            y = torch.concat([msk, lat_y])
            kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
            lat_y = msk = msk_control = msk_ref = pose_pixels = None
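The Animate branch above builds its conditioning y by stacking an i2v-style mask on top of the VAE latents of the reference frame plus the control frames. A schematic of that channel layout; every dimension below is made up for illustration and the 4-channel mask layout is an assumption, not read from the model:

import torch

lat_c, lat_t, lat_h, lat_w = 16, 21, 30, 52        # assumed latent dimensions
msk = torch.zeros(4, lat_t, lat_h, lat_w)          # assumed i2v mask channels (ref + control)
lat_y = torch.zeros(lat_c, lat_t, lat_h, lat_w)    # VAE latents of reference frame and control frames

y = torch.concat([msk, lat_y])                     # what ends up in kwargs['y']
assert y.shape[0] == 4 + lat_c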
@@ -701,12 +624,11 @@ class WanAny2V:

        # Phantom
        if phantom:
            input_ref_images_neg = None
            if input_ref_images != None: # Phantom Ref images
                input_ref_images = self.get_vae_latents(input_ref_images, self.device)
                input_ref_images_neg = torch.zeros_like(input_ref_images)
                ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
                trim_frames = input_ref_images.shape[1]
            lat_input_ref_images_neg = None
            if input_ref_images is not None: # Phantom Ref images
                lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
                lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
                ref_images_count = trim_frames = lat_input_ref_images.shape[1]

        if ti2v:
            if input_video is None:
@@ -721,25 +643,23 @@ class WanAny2V:
        # Vace
        if vace :
            # vace context encode
            input_frames = [u.to(self.device) for u in input_frames]
            input_ref_images = [ None if u == None else [v.to(self.device) for v in u]  for u in input_ref_images]
            input_masks = [u.to(self.device) for u in input_masks]
            input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
            input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
            input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
            input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
            ref_images_before = True
            if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
            z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
            m0 = self.vace_encode_masks(input_masks, input_ref_images)
            if self.background_mask != None:
                color_reference_frame = input_ref_images[0][0].clone()
                zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
                mbg = self.vace_encode_masks(self.background_mask, None)
            if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
                color_reference_frame = input_ref_images[0].clone()
                zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
                mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
                for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
                    zz0[:, 0:1] = zzbg
                    mm0[:, 0:1] = mmbg

                self.background_mask = zz0 = mm0 = zzbg = mmbg = None
            z = self.vace_latent(z0, m0)

            ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
                zz0 = mm0 = zzbg = mmbg = None
            z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
            ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0
            context_scale = context_scale if context_scale != None else [1.0] * len(z)
            kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
            if overlapped_latents != None :
@@ -747,15 +667,8 @@ class WanAny2V:
                extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
            if prefix_frames_count > 0:
                color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()

            target_shape = list(z0[0].shape)
            target_shape[0] = int(target_shape[0] / 2)
            lat_h, lat_w = target_shape[-2:]
            height = self.vae_stride[1] * lat_h
            width = self.vae_stride[2] * lat_w

        else:
            target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
        lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
        target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)

        if multitalk:
            if audio_proj is None:
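The new target_shape lines reduce the latent geometry to one formula: spatial dimensions divided by the VAE stride, plus one latent slot per reference image. A worked example with assumed constants (vae_stride = (4, 8, 8) and z_dim = 16 are typical for this model family but are hard-coded here purely for illustration):

z_dim = 16
vae_stride = (4, 8, 8)
height, width = 480, 832
lat_frames, ref_images_count = 21, 1

lat_h, lat_w = height // vae_stride[1], width // vae_stride[2]      # 60, 104
target_shape = (z_dim, lat_frames + ref_images_count, lat_h, lat_w)
print(target_shape)                                                 # (16, 22, 60, 104)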
@@ -860,7 +773,9 @@ class WanAny2V:
            apg_norm_threshold = 55
            text_momentumbuffer  = MomentumBuffer(apg_momentum)
            audio_momentumbuffer = MomentumBuffer(apg_momentum)

        input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
        gc.collect()
        torch.cuda.empty_cache()

        # denoising
        trans = self.model
@@ -878,7 +793,7 @@ class WanAny2V:
            kwargs.update({"t": timestep, "current_step": start_step_no + i})
            kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None

            if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
            if denoising_strength < 1 and i <= injection_denoising_step:
                sigma = t / 1000
                noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
                if inject_from_start:
@@ -912,8 +827,8 @@ class WanAny2V:
            any_guidance = guide_scale != 1
            if phantom:
                gen_args = {
                    "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
                        [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
                    "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
                        [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
                    "context": [context, context_null, context_null] ,
                }
            elif fantasy:
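The Phantom branch batches three denoiser inputs: the latent with the reference latents appended (twice, for the text-conditioned passes) and the same latent with zeroed reference latents for the unconditional pass, paired with [context, context_null, context_null]. A shape-only sketch of that bookkeeping, with invented dimensions:

import torch

latent  = torch.randn(1, 16, 20, 30, 52)   # invented (b, c, t, h, w)
ref     = torch.randn(1, 16, 1, 30, 52)    # reference latents appended on the time axis
ref_neg = torch.zeros_like(ref)            # zeroed refs for the negative branch

x_branches = [torch.cat([latent, ref], dim=2)] * 2 + [torch.cat([latent, ref_neg], dim=2)]
assert len(x_branches) == 3 and x_branches[0].shape[2] == 21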
@@ -21,6 +21,7 @@ class family_handler():
        extra_model_def["fps"] =fps
        extra_model_def["frames_minimum"] = 17
        extra_model_def["frames_steps"] = 20
        extra_model_def["latent_size"] = 4
        extra_model_def["sliding_window"] = True
        extra_model_def["skip_layer_guidance"] = True
        extra_model_def["tea_cache"] = True

@@ -114,7 +114,6 @@ class family_handler():
        "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
        "mag_cache" : True,
        "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
        "convert_image_guide_to_video" : True,
        "sample_solvers":[
                            ("unipc", "unipc"),
                            ("euler", "euler"),
@@ -175,6 +174,8 @@ class family_handler():
            extra_model_def["forced_guide_mask_inputs"] = True
            extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
            extra_model_def["background_ref_outpainted"] = False
            extra_model_def["return_image_refs_tensor"] = True
            extra_model_def["guide_inpaint_color"] = 0



@@ -196,15 +197,15 @@ class family_handler():
                    "letters_filter":  "KFI",
            }

            extra_model_def["lock_image_refs_ratios"] = True
            extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
            extra_model_def["video_guide_outpainting"] = [0,1]
            extra_model_def["pad_guide_video"] = True
            extra_model_def["guide_inpaint_color"] = 127.5
            extra_model_def["forced_guide_mask_inputs"] = True
            extra_model_def["return_image_refs_tensor"] = True

        if base_model_type in ["standin"]:
            extra_model_def["lock_image_refs_ratios"] = True
            extra_model_def["fit_into_canvas_image_refs"] = 0
            extra_model_def["image_ref_choices"] = {
                "choices": [
                    ("No Reference Image", ""),
@@ -480,6 +481,7 @@ class family_handler():
            ui_defaults.update({
                "video_prompt_type": "PVBXAKI",
                "mask_expand": 20,
                "audio_prompt_type_value": "R",
            })

        if text_oneframe_overlap(base_model_type):
@@ -32,6 +32,14 @@ def seed_everything(seed: int):
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

def has_video_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".mp4"]

def has_image_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
    import math

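The two new helpers only inspect the file extension, which makes them cheap to use for routing mixed inputs. A small usage sketch with the same definitions repeated so it runs on its own:

import os

def has_video_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [".mp4"]

def has_image_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [".png", ".jpg", ".jpeg", ".bmp", ".gif",
                                                      ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]

files = ["clip.mp4", "ref.png", "mask.webp", "notes.txt"]
videos = [f for f in files if has_video_file_extension(f)]   # ['clip.mp4']
images = [f for f in files if has_image_file_extension(f)]   # ['ref.png', 'mask.webp']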
@@ -94,7 +102,7 @@ def get_video_info(video_path):

    return fps, width, height, frame_count

def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None,  return_PIL = True) -> torch.Tensor:
    """Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
    cap = cv2.VideoCapture(file_name)

@@ -102,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
        raise ValueError(f"Cannot open video: {file_name}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fps = round(cap.get(cv2.CAP_PROP_FPS))
    if target_fps is not None:
        frame_no = round(target_fps * frame_no /fps)

    # Handle out of bounds
    if frame_no >= total_frames or frame_no < 0:
        if return_last_if_missing:
@@ -175,10 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)

def convert_tensor_to_image(t, frame_no = 0):
def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
    if len(t.shape) == 4:
        t = t[:, frame_no]
    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
    if t.shape[0]== 1:
        t = t.expand(3,-1,-1)
    if mask_levels:
        return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
    else:
        return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())

def save_image(tensor_image, name, frame_no = -1):
    convert_tensor_to_image(tensor_image, frame_no).save(name)
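convert_tensor_to_image now handles two value ranges: regular frames in [-1, 1] (scaled by 127.5 after adding 1) and, with mask_levels=True, single-channel masks in [0, 1] (scaled by 255 and broadcast to three channels). A self-contained round trip of both paths:

import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (8, 8), (200, 100, 50))
t = torch.from_numpy(np.array(img).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)   # (3, 8, 8) in [-1, 1]
back = Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1, 2, 0).to(torch.uint8).cpu().numpy())

mask = torch.rand(1, 8, 8).expand(3, -1, -1)                                                 # mask in [0, 1]
mask_img = Image.fromarray(mask.clone().mul_(255).permute(1, 2, 0).to(torch.uint8).cpu().numpy())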
@@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
        image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return image, new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
    if rm_background:
        session = new_session()

@@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
    for i, img in enumerate(img_list):
        width, height =  img.size
        resized_mask = None
        if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
        if any_background_ref == 1 and i==0 or any_background_ref == 2:
            if outpainting_dims is not None and background_ref_outpainted:
                resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
            elif img.size != (budget_width, budget_height):
@@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
        if rm_background  and not (any_background_ref and i==0 or any_background_ref == 2) :
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
        if return_tensor:
            output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
        else:
            output_list.append(resized_image)
        output_mask_list.append(resized_mask)
    return output_list, output_mask_list

@@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu

    return ref_img.to(device), canvas

def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [],  inject_frames = [], outpainting_dims = None ):
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [],  inject_frames = [], outpainting_dims = None, device ="cpu"):
    src_videos, src_masks = [], []
    inpaint_color = guide_inpaint_color/127.5 - 1
    prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
    inpaint_color_compressed = guide_inpaint_color/127.5 - 1
    prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
    for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
        src_video = src_mask = None
        if cur_video_guide is not None:
            src_video  = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
        if cur_video_mask is not None and any_mask:
            src_mask  = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
        if pre_video_guide is not None and not extract_guide_from_window_start:
        src_video, src_mask = cur_video_guide, cur_video_mask
        if pre_video_guide is not None:
            src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
            if any_mask:
                src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
        if src_video is None:
            if any_guide_padding:
                src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
                if any_mask:
                    src_mask = torch.zeros_like(src_video[0:1])
        elif src_video.shape[1] < current_video_length:
            if any_guide_padding:
                src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:]  ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
                if cur_video_mask is not None and any_mask:
                    src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:]  ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)

        if any_guide_padding:
            if src_video is None:
                src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
            elif src_video.shape[1] < current_video_length:
                src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:]  ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
        elif src_video is not None:
            new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
            src_video = src_video[:, :new_num_frames]

        if any_mask and src_video is not None:
            if src_mask is None:
                src_mask = torch.ones_like(src_video[:1])
            elif src_mask.shape[1] < src_video.shape[1]:
                src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:]  ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
            else:
                new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
                src_video = src_video[:, :new_num_frames]
                if any_mask:
                    src_mask = src_mask[:, :new_num_frames]
                src_mask = src_mask[:, :src_video.shape[1]]

        for k, keep in enumerate(keep_video_guide_frames):
            if not keep:
                pos = prepend_count + k
                src_video[:, pos:pos+1] = inpaint_color
                src_mask[:, pos:pos+1] = 1

        for k, frame in enumerate(inject_frames):
            if frame != None:
                pos = prepend_count + k
                src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
        if src_video is not None :
            for k, keep in enumerate(keep_video_guide_frames):
                if not keep:
                    pos = prepend_count + k
                    src_video[:, pos:pos+1] = inpaint_color_compressed
                    if any_mask: src_mask[:, pos:pos+1] = 1

            for k, frame in enumerate(inject_frames):
                if frame != None:
                    pos = prepend_count + k
                    src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
                    if any_mask: src_mask[:, pos:pos+1] = msk
        src_videos.append(src_video)
        src_masks.append(src_mask)
    return src_videos, src_masks
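The reworked prepare_video_guide_and_mask applies two shape rules to each (c, f, h, w) guide: pad to the requested length with the compressed inpaint color when any_guide_padding is set, otherwise trim the frame count to a multiple of latent_size plus one. A compact, self-contained sketch of just that arithmetic:

import torch

def pad_or_trim(guide, target_len, latent_size, guide_inpaint_color=127.5, pad=True):
    inpaint = guide_inpaint_color / 127.5 - 1          # pixel 127.5 becomes 0.0 in [-1, 1]
    if pad and guide.shape[1] < target_len:
        filler = torch.full((guide.shape[0], target_len - guide.shape[1], *guide.shape[-2:]),
                            inpaint, dtype=guide.dtype)
        return torch.cat([guide, filler], dim=1)
    new_len = (guide.shape[1] - 1) // latent_size * latent_size + 1
    return guide[:, :new_len]

g = torch.zeros(3, 70, 32, 32)
assert pad_or_trim(g, 81, 4, pad=True).shape[1] == 81
assert pad_or_trim(g, 81, 4, pad=False).shape[1] == 69   # (70 - 1) // 4 * 4 + 1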
							
								
								
									
wgp.py
@@ -24,6 +24,7 @@ from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.utils import has_video_file_extension, has_image_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files,  save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@@ -62,7 +63,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.6.0"
WanGP_version = "8.61"
WanGP_version = "8.7"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type):
    mode_def = get_model_def(model_type)
    frames_minimum = mode_def.get("frames_minimum", 5)
    frames_steps = mode_def.get("frames_steps", 4)
    return frames_minimum, frames_steps
    latent_size = mode_def.get("latent_size", frames_steps)
    return frames_minimum, frames_steps, latent_size

def get_model_fps(model_type):
    mode_def = get_model_def(model_type)
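get_model_min_frames_and_step now returns a third value, the latent size, and the caller further down in this diff switches from "_, latent_size = ..." to "_, _, latent_size = ..." to match. A runnable sketch with a hypothetical stand-in for get_model_def, only so the example executes on its own:

def get_model_def(model_type):
    # hypothetical stub; the real function reads the model definition
    return {"frames_minimum": 17, "frames_steps": 20, "latent_size": 4}

def get_model_min_frames_and_step(model_type):
    mode_def = get_model_def(model_type)
    frames_minimum = mode_def.get("frames_minimum", 5)
    frames_steps = mode_def.get("frames_steps", 4)
    latent_size = mode_def.get("latent_size", frames_steps)
    return frames_minimum, frames_steps, latent_size

_, _, latent_size = get_model_min_frames_and_step("animate")   # callers that only need the latent size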
@@ -3459,7 +3461,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
            if len(video_other_prompts) >0 :
                values += [video_other_prompts]
                labels += ["Other Prompts"]
            if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
            if len(video_outpainting) >0:
                values += [video_outpainting]
                labels += ["Outpainting"]
            video_sample_solver = configs.get("sample_solver", "")
@@ -3532,6 +3534,11 @@ def convert_image(image):
    return cast(Image, ImageOps.exif_transpose(image))

def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
    if isinstance(video_in, str) and has_image_file_extension(video_in):
        video_in = Image.open(video_in)
    if isinstance(video_in, Image.Image):
        return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)

    from shared.utils.utils import resample

    import decord
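get_resampled_video now accepts a still image path and returns it as a one-frame video: the PIL image becomes a uint8 tensor of shape (1, h, w, c). A minimal equivalent:

import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (16, 16), (255, 0, 0))
frame = torch.from_numpy(np.array(img).astype(np.uint8)).unsqueeze(0)
assert frame.shape == (1, 16, 16, 3) and frame.dtype == torch.uint8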
@@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) :
    if not items:
       return []
    max_workers = 11

    import concurrent.futures
    start_time = time.time()
    # print(f"Preprocess:{process_type} started")
    if process_type in ["prephase", "upsample"]:
        if wrap_in_list :
            items = [ [img] for img in items]
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
            results = [None] * len(items)
            for future in concurrent.futures.as_completed(futures):
                idx = futures[future]
                results[idx] = future.result()
        if max_workers == 1:
            results = [image_processor(img) for img in items]
        else:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
                results = [None] * len(items)
                for future in concurrent.futures.as_completed(futures):
                    idx = futures[future]
                    results[idx] = future.result()

        if wrap_in_list:
            results = [ img[0] for img in results]
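The rewritten branch keeps results in input order even though futures complete out of order, and short-circuits to a plain loop when max_workers is 1. The same pattern in isolation:

import concurrent.futures

def run_indexed(fn, items, max_workers=4):
    # results[idx] keeps the original ordering regardless of completion order
    if max_workers == 1:
        return [fn(x) for x in items]
    results = [None] * len(items)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(fn, x): idx for idx, x in enumerate(items)}
        for future in concurrent.futures.as_completed(futures):
            results[futures[future]] = future.result()
    return results

assert run_indexed(lambda x: x * x, [1, 2, 3, 4]) == [1, 4, 9, 16]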
@@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis

    return results

def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
    frame_width, frame_height = input_image.size

    if fit_crop:
        input_image = rescale_and_crop(input_image, width, height)
        if input_mask is not None:
            input_mask = rescale_and_crop(input_mask, width, height)
        return input_image, input_mask

    if outpainting_dims != None:
        if fit_canvas != None:
            frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
        else:
            frame_height, frame_width = height, width

    if fit_canvas != None:
        height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)

    if outpainting_dims != None:
        final_height, final_width = height, width
        height, width, margin_top, margin_left =  get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 1)

    if fit_canvas != None or outpainting_dims != None:
        input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        if input_mask is not None:
            input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)

    if expand_scale != 0 and input_mask is not None:
        kernel_size = abs(expand_scale)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
        input_mask = np.array(input_mask)
        input_mask = op_expand(input_mask, kernel, iterations=3)
        input_mask = Image.fromarray(input_mask)

    if outpainting_dims != None:
        inpaint_color = inpaint_color / 127.5-1
        image = convert_image_to_tensor(input_image)
        full_frame= torch.full( (image.shape[0], final_height, final_width),  inpaint_color, dtype= torch.float, device= image.device)
        full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
        input_image = convert_tensor_to_image(full_frame)

        if input_mask is not None:
            mask = convert_image_to_tensor(input_mask)
            full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
            full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
            input_mask = convert_tensor_to_image(full_frame)

    return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
    if not input_video_path or max_frames <= 0:
        return None, None
@@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
        save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
    return face_tensor

def get_default_workers():
    return os.cpu_count()/ 2

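One aside on get_default_workers: os.cpu_count() / 2 is a float (and os.cpu_count() can be None), while ThreadPoolExecutor documents max_workers as an int. An int-safe variant, offered only as a defensive sketch rather than the repository's code:

import os

def default_workers() -> int:
    return max(1, (os.cpu_count() or 2) // 2)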
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,  max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):

@@ -3906,8 +3869,8 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
            return (target_frame, frame, mask)
        else:
            return (target_frame, None, None)

    proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
    max_workers = get_default_workers()
    proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers)
    proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
    for frame_idx, frame_group in enumerate(proc_lists):
        proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
@@ -3916,11 +3879,11 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
    mask_video = None

    if preproc2 != None:
        proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
        proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
        #### to be finished ...or not
    proc_list = process_images_multithread(preproc, proc_list, process_type)
    proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
    if any_mask:
        proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
        proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
    else:
        proc_list_outside = proc_mask = len(proc_list) * [None]

@@ -3938,7 +3901,7 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
                full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
                full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
                mask = full_frame
            masks.append(mask)
            masks.append(mask[:, :, 0:1].clone())
        else:
            masked_frame = processed_img

@@ -3958,13 +3921,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
        proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None


    if args.save_masks:
        from preprocessing.dwpose.pose import save_one_video
        saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
        save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
        if any_mask:
            saved_masks = [mask.cpu().numpy() for mask in masks ]
            save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
    # if args.save_masks:
    #     from preprocessing.dwpose.pose import save_one_video
    #     saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
    #     save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
    #     if any_mask:
    #         saved_masks = [mask.cpu().numpy() for mask in masks ]
    #         save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
    preproc = None
    preproc_outside = None
    gc.collect()
@@ -3972,8 +3935,10 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
    if pad_frames > 0:
        masked_frames = masked_frames[0] * pad_frames + masked_frames
        if any_mask: masked_frames = masks[0] * pad_frames + masks
    masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
    masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None

    return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
    return masked_frames, masks

def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):

@@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):
    frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
    def upsample_frames(frame):
        return resize_lanczos(frame, h, w).unsqueeze(1)
    sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
    sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1)
    frames_to_upsample = None
    return sample

@@ -4609,17 +4574,13 @@ def generate_video(
        batch_size = 1
    temp_filenames_list = []

    convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
    if convert_image_guide_to_video:
        if image_guide is not None and isinstance(image_guide, Image.Image):
            video_guide = convert_image_to_video(image_guide)
            temp_filenames_list.append(video_guide)
            image_guide = None
    if image_guide is not None and isinstance(image_guide, Image.Image):
        video_guide = image_guide
        image_guide = None

        if image_mask is not None and isinstance(image_mask, Image.Image):
            video_mask = convert_image_to_video(image_mask)
            temp_filenames_list.append(video_mask)
            image_mask = None
    if image_mask is not None and isinstance(image_mask, Image.Image):
        video_mask = image_mask
        image_mask = None

    if model_def.get("no_background_removal", False): remove_background_images_ref = 0

@@ -4711,22 +4672,12 @@ def generate_video(
    device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
    guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
    extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
    i2v = test_class_i2v(model_type)
    diffusion_forcing = "diffusion_forcing" in model_filename
    t2v = base_model_type in ["t2v"]
    ltxv = "ltxv" in model_filename
    vace =  test_vace_module(base_model_type)
    hunyuan_t2v = "hunyuan_video_720" in model_filename
    hunyuan_i2v = "hunyuan_video_i2v" in model_filename
    hunyuan_custom = "hunyuan_video_custom" in model_filename
    hunyuan_custom_audio =  hunyuan_custom and "audio" in model_filename
    hunyuan_custom_edit =  hunyuan_custom and "edit" in model_filename
    hunyuan_avatar = "hunyuan_video_avatar" in model_filename
    fantasy = base_model_type in ["fantasy"]
    multitalk = model_def.get("multitalk_class", False)
    standin = model_def.get("standin_class", False)
    infinitetalk = base_model_type in ["infinitetalk"]
    animate =  base_model_type in ["animate"]

    if "B" in audio_prompt_type or "X" in audio_prompt_type:
        from models.wan.multitalk.multitalk import parse_speakers_locations
@@ -4763,9 +4714,9 @@ def generate_video(
        sliding_window_size = current_video_length
        reuse_frames = 0

    _, latent_size = get_model_min_frames_and_step(model_type)
    if diffusion_forcing: latent_size = 4
    _, _, latent_size = get_model_min_frames_and_step(model_type)
    original_image_refs = image_refs
    image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
    # image_refs = None
    # nb_frames_positions= 0
    # Output Video Ratio Priorities:
@@ -4889,6 +4840,7 @@ def generate_video(
    initial_total_windows = 0
    discard_last_frames = sliding_window_discard_last_frames
    default_requested_frames_to_generate = current_video_length
    nb_frames_positions = 0
    if sliding_window:
        initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
        current_video_length = sliding_window_size
@@ -4907,7 +4859,7 @@ def generate_video(
        if repeat_no >= total_generation: break
        repeat_no +=1
        gen["repeat_no"] = repeat_no
        src_video = src_mask = src_ref_images = new_image_guide = new_image_mask  = src_faces = None
        src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
        prefix_video = pre_video_frame = None
        source_video_overlap_frames_count = 0 # number of frames overlapped in source video for first window
        source_video_frames_count = 0  # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@@ -4963,7 +4915,6 @@ def generate_video(
                return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
            refresh_preview  = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}

            src_ref_images, src_ref_masks  = image_refs, None
            image_start_tensor = image_end_tensor = None
            if window_no == 1 and (video_source is not None or image_start is not None):
                if image_start is not None:
@@ -5020,7 +4971,7 @@ def generate_video(
                        if len(pos) > 0:
                            if pos in ["L", "l"]:
                                cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
                                if cur_end_pos >= last_frame_no and not joker_used:
                                if cur_end_pos >= last_frame_no-1 and not joker_used:
                                    joker_used = True
                                    cur_end_pos = last_frame_no -1
                                    project_window_no += 1
@ -5036,141 +4987,53 @@ def generate_video(
 | 
			
		||||
                        frames_to_inject[pos] = image_refs[i] 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
 | 
			
		||||
            if video_guide is not None:
 | 
			
		||||
                keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
 | 
			
		||||
                if len(error) > 0:
 | 
			
		||||
                    raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
 | 
			
		||||
                guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
 | 
			
		||||
                extra_control_frames = model_def.get("extra_control_frames", 0)
 | 
			
		||||
                if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
 | 
			
		||||
                        
 | 
			
		||||
                keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start  <0 else []
 | 
			
		||||
                keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] 
 | 
			
		||||
                guide_frames_extract_count = len(keep_frames_parsed)
 | 
			
		||||
 | 
			
		||||
                # Extract faces into a face video
 | 
			
		||||
                if "B" in video_prompt_type:
 | 
			
		||||
                    send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
 | 
			
		||||
                    src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
 | 
			
		||||
                    if src_faces is not None and src_faces.shape[1] < current_video_length:
 | 
			
		||||
                        src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
 | 
			
		||||
 | 
			
		||||
                if vace or animate:
 | 
			
		||||
                    video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
 | 
			
		||||
                    context_scale = [ control_net_weight]
 | 
			
		||||
                    if "V" in video_prompt_type:
 | 
			
		||||
                        process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
 | 
			
		||||
                        preprocess_type, preprocess_type2 =  "raw", None 
 | 
			
		||||
                        for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
 | 
			
		||||
                            if process_num == 0:
 | 
			
		||||
                                preprocess_type = process_map_video_guide.get(process_letter, "raw")
 | 
			
		||||
                            else:
 | 
			
		||||
                                preprocess_type2 = process_map_video_guide.get(process_letter, None)
 | 
			
		||||
                        status_info = "Extracting " + processes_names[preprocess_type]
 | 
			
		||||
                        extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
 | 
			
		||||
                        if len(extra_process_list) == 1:
 | 
			
		||||
                            status_info += " and " + processes_names[extra_process_list[0]]
 | 
			
		||||
                        elif len(extra_process_list) == 2:
 | 
			
		||||
                            status_info +=  ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
 | 
			
		||||
                        if preprocess_type2 is not None:
 | 
			
		||||
                            context_scale = [ control_net_weight /2, control_net_weight2 /2]
 | 
			
		||||
                        send_cmd("progress", [0, get_latest_status(state, status_info)])
 | 
			
		||||
                        inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
 | 
			
		||||
                        video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
 | 
			
		||||
                        if preprocess_type2 != None:
 | 
			
		||||
                            video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps,  process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
 | 
			
		||||
                # Sparse Video to Video
 | 
			
		||||
                sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None
 | 
			
		||||
 | 
			
		||||
                        if video_guide_processed != None:
 | 
			
		||||
                            if sample_fit_canvas != None:
 | 
			
		||||
                                image_size = video_guide_processed.shape[-3: -1]
 | 
			
		||||
                                sample_fit_canvas = None
 | 
			
		||||
                            refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
 | 
			
		||||
                            if video_guide_processed2 != None:
 | 
			
		||||
                                refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] 
 | 
			
		||||
                            if video_mask_processed != None:                        
 | 
			
		||||
                                refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
 | 
			
		||||
 | 
			
		||||
                    frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
 | 
			
		||||
 | 
			
		||||
                    if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): 
 | 
			
		||||
                        any_mask = True
 | 
			
		||||
                        any_guide_padding = model_def.get("pad_guide_video", False)
 | 
			
		||||
                        from shared.utils.utils import prepare_video_guide_and_mask
 | 
			
		||||
                        src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], 
 | 
			
		||||
                                                                                [video_mask_processed, video_mask_processed2], 
 | 
			
		||||
                                                                                pre_video_guide, image_size, current_video_length, latent_size,
 | 
			
		||||
                                                                                any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
 | 
			
		||||
                                                                                keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
 | 
			
		||||
 | 
			
		||||
                        src_video, src_video2 = src_videos 
 | 
			
		||||
                        src_mask, src_mask2 = src_masks 
 | 
			
		||||
                        if src_video is None:
 | 
			
		||||
                            abort = True 
 | 
			
		||||
                            break
 | 
			
		||||
                        if src_faces is not None:
 | 
			
		||||
                            if src_faces.shape[1] < src_video.shape[1]:
 | 
			
		||||
                                src_faces = torch.concat( [src_faces,  src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
 | 
			
		||||
                            else:
 | 
			
		||||
                                src_faces = src_faces[:, :src_video.shape[1]]
 | 
			
		||||
                        if args.save_masks:
 | 
			
		||||
                            save_video( src_video, "masked_frames.mp4", fps)
 | 
			
		||||
                            if src_video2 is not None:
 | 
			
		||||
                                save_video( src_video2, "masked_frames2.mp4", fps)
 | 
			
		||||
                            if any_mask:
 | 
			
		||||
                                save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
 | 
			
		||||
 | 
			
		||||
                elif ltxv:          
                    preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
                    status_info = "Extracting " + processes_names[preprocess_type]
                    send_cmd("progress", [0, get_latest_status(state, status_info)])
                    # start one frame earlier to facilitate latents merging later
                    src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps,  process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
 | 
			
		||||
                    if src_video !=  None:
 | 
			
		||||
                        src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
 | 
			
		||||
                        refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
 | 
			
		||||
                        refresh_preview["video_mask"] = None
 | 
			
		||||
                        src_video  = src_video.permute(3, 0, 1, 2)
 | 
			
		||||
                        src_video  = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
 | 
			
		||||
                        if sample_fit_canvas != None:
 | 
			
		||||
                            image_size = src_video.shape[-2:]
 | 
			
		||||
                            sample_fit_canvas = None
 | 
			
		||||
 | 
			
		||||
                elif hunyuan_custom_edit:
 | 
			
		||||
                    if "P" in  video_prompt_type:
 | 
			
		||||
                        progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
 | 
			
		||||
                # Generic Video Preprocessing
 | 
			
		||||
                process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
 | 
			
		||||
                preprocess_type, preprocess_type2 =  "raw", None 
 | 
			
		||||
                for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
 | 
			
		||||
                    if process_num == 0:
 | 
			
		||||
                        preprocess_type = process_map_video_guide.get(process_letter, "raw")
 | 
			
		||||
                    else:
 | 
			
		||||
                        progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
 | 
			
		||||
                        preprocess_type2 = process_map_video_guide.get(process_letter, None)
 | 
			
		||||
                status_info = "Extracting " + processes_names[preprocess_type]
 | 
			
		||||
                extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
 | 
			
		||||
                if len(extra_process_list) == 1:
 | 
			
		||||
                    status_info += " and " + processes_names[extra_process_list[0]]
 | 
			
		||||
                elif len(extra_process_list) == 2:
 | 
			
		||||
                    status_info +=  ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
 | 
			
		||||
                context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
 | 
			
		||||
                if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None): send_cmd("progress", [0, get_latest_status(state, status_info)])
 | 
			
		||||
                inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
 | 
			
		||||
                video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps,  process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size )
 | 
			
		||||
                if preprocess_type2 != None:
 | 
			
		||||
                    video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps,  process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size )
 | 
			
		||||
 | 
			
		||||
                    send_cmd("progress", progress_args)
 | 
			
		||||
                    src_video, src_mask = preprocess_video_with_mask(video_guide,  video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
 | 
			
		||||
                    refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) 
 | 
			
		||||
                    if src_mask != None:                        
 | 
			
		||||
                        refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
 | 
			
		||||
 | 
			
		||||
                elif "R" in video_prompt_type: # sparse video to video
 | 
			
		||||
                    src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
 | 
			
		||||
                    src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1  ], sample_fit_canvas, fit_crop, block_size = block_size)
 | 
			
		||||
                    refresh_preview["video_guide"] = src_image  
 | 
			
		||||
                    src_video = convert_image_to_tensor(src_image).unsqueeze(1)
 | 
			
		||||
                    if sample_fit_canvas != None:  
 | 
			
		||||
                        image_size  = src_video.shape[-2:]
 | 
			
		||||
                        sample_fit_canvas = None
 | 
			
		||||
 | 
			
		||||
                else: # video to video
 | 
			
		||||
                    video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
 | 
			
		||||
                    if video_guide_processed is None:
 | 
			
		||||
                        src_video = pre_video_guide
 | 
			
		||||
                    else:
 | 
			
		||||
                        if sample_fit_canvas != None:
 | 
			
		||||
                            image_size = video_guide_processed.shape[-3: -1]
 | 
			
		||||
                            sample_fit_canvas = None
 | 
			
		||||
                        src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
 | 
			
		||||
                        if pre_video_guide != None:
 | 
			
		||||
                            src_video = torch.cat( [pre_video_guide, src_video], dim=1) 
 | 
			
		||||
            elif image_guide is not None:
 | 
			
		||||
                new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims)
 | 
			
		||||
                if sample_fit_canvas is not None:
 | 
			
		||||
                    image_size = (new_image_guide.size[1], new_image_guide.size[0])
 | 
			
		||||
                if video_guide_processed is not None  and sample_fit_canvas is not None:
 | 
			
		||||
                    image_size = video_guide_processed.shape[-2:]
 | 
			
		||||
                    sample_fit_canvas = None
 | 
			
		||||
                refresh_preview["image_guide"] = new_image_guide
 | 
			
		||||
                if new_image_mask is not None:
 | 
			
		||||
                    refresh_preview["image_mask"] = new_image_mask
 | 
			
		||||
 | 
			
		||||
            if window_no == 1 and image_refs is not None and len(image_refs) > 0:
 | 
			
		||||
                if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
 | 
			
		||||
@@ -5192,45 +5055,68 @@ def generate_video(
 | 
			
		||||
                            image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0])
 | 
			
		||||
                        refresh_preview["image_refs"] = image_refs
 | 
			
		||||
 | 
			
		||||
                    if len(image_refs) > nb_frames_positions:  
 | 
			
		||||
                    if len(image_refs) > nb_frames_positions:
 | 
			
		||||
                        src_ref_images = image_refs[nb_frames_positions:]
 | 
			
		||||
                        if remove_background_images_ref > 0:
 | 
			
		||||
                            send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
 | 
			
		||||
                        # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
 | 
			
		||||
                        image_refs[nb_frames_positions:], src_ref_masks  = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
 | 
			
		||||
 | 
			
		||||
                        src_ref_images, src_ref_masks  = resize_and_remove_background(src_ref_images , image_size[1], image_size[0],
 | 
			
		||||
                                                                                        remove_background_images_ref > 0, any_background_ref, 
 | 
			
		||||
                                                                                        fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
 | 
			
		||||
                                                                                        fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1),
 | 
			
		||||
                                                                                        block_size=block_size,
 | 
			
		||||
                                                                                        outpainting_dims =outpainting_dims,
 | 
			
		||||
                                                                                        background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
 | 
			
		||||
                        refresh_preview["image_refs"] = image_refs
 | 
			
		||||
                                                                                        background_ref_outpainted = model_def.get("background_ref_outpainted", True),
 | 
			
		||||
                                                                                        return_tensor= model_def.get("return_image_refs_tensor", False) )
 | 
			
		||||
 
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            if vace :
 | 
			
		||||
                image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source does in-place modifications
 | 
			
		||||
 | 
			
		||||
                src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
 | 
			
		||||
                                                                        [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2],
 | 
			
		||||
                                                                        [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], 
 | 
			
		||||
                                                                        current_video_length, image_size = image_size, device ="cpu",
 | 
			
		||||
                                                                        keep_video_guide_frames=keep_frames_parsed,
 | 
			
		||||
                                                                        pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
 | 
			
		||||
                                                                        inject_frames= frames_to_inject_parsed,
 | 
			
		||||
                                                                        outpainting_dims = outpainting_dims,
 | 
			
		||||
                                                                        any_background_ref = any_background_ref
 | 
			
		||||
                                                                        )
 | 
			
		||||
                if len(frames_to_inject_parsed) or any_background_ref:
 | 
			
		||||
                    new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
 | 
			
		||||
                    if any_background_ref:
 | 
			
		||||
                        new_image_refs +=  [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
 | 
			
		||||
            frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
 | 
			
		||||
            if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): 
 | 
			
		||||
                any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
 | 
			
		||||
                any_guide_padding = model_def.get("pad_guide_video", False)
 | 
			
		||||
                from shared.utils.utils import prepare_video_guide_and_mask
 | 
			
		||||
                src_videos, src_masks = prepare_video_guide_and_mask(   [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), 
 | 
			
		||||
                                                                        [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
 | 
			
		||||
                                                                        None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, 
 | 
			
		||||
                                                                        image_size, current_video_length, latent_size,
 | 
			
		||||
                                                                        any_mask, any_guide_padding, guide_inpaint_color, 
 | 
			
		||||
                                                                        keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
 | 
			
		||||
                video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
 | 
			
		||||
                if len(src_videos) == 1:
 | 
			
		||||
                    src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None 
 | 
			
		||||
                else:
 | 
			
		||||
                    src_video, src_video2 = src_videos 
 | 
			
		||||
                    src_mask, src_mask2 = src_masks 
 | 
			
		||||
                src_videos = src_masks = None
 | 
			
		||||
                if src_video is None:
 | 
			
		||||
                    abort = True 
 | 
			
		||||
                    break
 | 
			
		||||
                if src_faces is not None:
 | 
			
		||||
                    if src_faces.shape[1] < src_video.shape[1]:
 | 
			
		||||
                        src_faces = torch.concat( [src_faces,  src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
 | 
			
		||||
                    else:
 | 
			
		||||
                        new_image_refs +=  image_refs[nb_frames_positions:]
 | 
			
		||||
                    refresh_preview["image_refs"] = new_image_refs
 | 
			
		||||
                    new_image_refs = None
 | 
			
		||||
 | 
			
		||||
                if sample_fit_canvas != None:
 | 
			
		||||
                    image_size = src_video[0].shape[-2:]
 | 
			
		||||
                    sample_fit_canvas = None
 | 
			
		||||
                        src_faces = src_faces[:, :src_video.shape[1]]
 | 
			
		||||
                if video_guide is not None or len(frames_to_inject_parsed) > 0:
 | 
			
		||||
                    if args.save_masks:
 | 
			
		||||
                        if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
 | 
			
		||||
                        if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
 | 
			
		||||
                        if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
 | 
			
		||||
                if video_guide is not None:                        
 | 
			
		||||
                    preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) 
 | 
			
		||||
                    refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
 | 
			
		||||
                    if src_video2 is not None:
 | 
			
		||||
                        refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] 
 | 
			
		||||
                    if src_mask is not None and video_mask is not None:                        
 | 
			
		||||
                        refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True)
 | 
			
		||||
 | 
			
		||||
            if src_ref_images is not None or nb_frames_positions:
 | 
			
		||||
                if len(frames_to_inject_parsed):
 | 
			
		||||
                    new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
 | 
			
		||||
                else:
 | 
			
		||||
                    new_image_refs = []
 | 
			
		||||
                if src_ref_images is not None:
 | 
			
		||||
                    new_image_refs +=  [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images  ]
 | 
			
		||||
                refresh_preview["image_refs"] = new_image_refs
 | 
			
		||||
                new_image_refs = None
 | 
			
		||||
 | 
			
		||||
            if len(refresh_preview) > 0:
 | 
			
		||||
                new_inputs= locals()
 | 
			
		||||
@@ -5339,8 +5225,6 @@ def generate_video(
                    pre_video_frame = pre_video_frame,
                    original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
                    image_refs_relative_size = image_refs_relative_size,
                    image_guide= new_image_guide,
                    image_mask= new_image_mask,
                    outpainting_dims = outpainting_dims,
                )
            except Exception as e:
@@ -6320,8 +6204,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
            pop += ["image_refs_relative_size"]

    if not vace:
        pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
        pop += ["frames_positions", "control_net_weight", "control_net_weight2"]

    if model_def.get("video_guide_outpainting", None) is None:
        pop += ["video_guide_outpainting"]

    if not (vace or t2v):
        pop += ["min_frames_if_references"]

@@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice):
        choice = min(choice, len(file_list))
    return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)

def has_video_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".mp4"]

def has_image_file_extension(filename):
    extension = os.path.splitext(filename)[-1].lower()
    return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
    gen = get_gen_info(state)
    if files_to_load == None:
@@ -7881,7 +7761,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                elif recammaster:
                    video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
                else:
                    min_frames, frames_step = get_model_min_frames_and_step(base_model_type)
                    min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)

                    current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)

@@ -8059,7 +7939,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                            MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)")

                    with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row:
                    with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row:
                        gr.Markdown("<B>You may transfer the existing audio tracks of a Control Video</B>")
                        audio_prompt_type_remux = gr.Dropdown(
                            choices=[
@@ -8284,16 +8164,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
                        with gr.Row(**default_visibility) as video_buttons_row:
                            video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
                            video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
                            video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
                            video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
                            video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
                        with gr.Row(**default_visibility) as image_buttons_row:
                            video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
                            video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
                            video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
                            video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
                            video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False)
                            video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
                            video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
                            video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
                    with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
                        with gr.Group(elem_classes= "postprocess"):